From 2a0baa80b0642bce8f293d3f1dbd96a0f1bb3718 Mon Sep 17 00:00:00 2001 From: "Amin.MasterkinG" Date: Mon, 20 Apr 2026 19:24:38 +0330 Subject: [PATCH] Server side. --- client.toml | 20 ++ client.go => cmd/client/client.go | 5 +- cmd/server/server.go | 42 ++++ internal/client/sender_workers.go | 33 +-- internal/config/config.go | 139 +++++++---- internal/protocol/packet.go | 62 +++++ internal/server/server.go | 371 ++++++++++++++++++++++++++++++ server.toml | 15 ++ 8 files changed, 614 insertions(+), 73 deletions(-) create mode 100644 client.toml rename client.go => cmd/client/client.go (84%) create mode 100644 cmd/server/server.go create mode 100644 internal/server/server.go create mode 100644 server.toml diff --git a/client.toml b/client.toml new file mode 100644 index 0000000..3baf1a8 --- /dev/null +++ b/client.toml @@ -0,0 +1,20 @@ +# ============================================================================== +AES_ENCRYPTION_KEY = "c4710a45afed2fdc00e0522c70802e71" +RELAY_URL = "http://127.0.0.1/relay.php" +# ============================================================================== +LOG_LEVEL = "INFO" +# ============================================================================== +SOCKS_HOST = "127.0.0.1" +SOCKS_PORT = 18001 +SOCKS_AUTH = false +SOCKS_USERNAME = "your_socks_username_here" +SOCKS_PASSWORD = "your_socks_password_here" +# ============================================================================== +MAX_CHUNK_SIZE = 16384 +MAX_PACKETS_PER_BATCH = 32 +MAX_BATCH_BYTES = 262144 +WORKER_COUNT = 4 +HTTP_REQUEST_TIMEOUT_MS = 15000 +WORKER_POLL_INTERVAL_MS = 200 +MAX_QUEUE_BYTES_PER_SOCKS = 1048576 +# ============================================================================== diff --git a/client.go b/cmd/client/client.go similarity index 84% rename from client.go rename to cmd/client/client.go index ba2d508..59579a6 100644 --- a/client.go +++ b/cmd/client/client.go @@ -21,10 +21,13 @@ import ( func main() { logger := lg.New("MasterHttpRelayVPN Client", "INFO") - cfg, err := config.Load("config.toml") + cfg, err := config.Load("client.toml") if err != nil { logger.Fatalf("load config: %v", err) } + if err := cfg.ValidateClient(); err != nil { + logger.Fatalf("validate client config: %v", err) + } logger = lg.New("MasterHttpRelayVPN Client", cfg.LogLevel) diff --git a/cmd/server/server.go b/cmd/server/server.go new file mode 100644 index 0000000..3f60d7a --- /dev/null +++ b/cmd/server/server.go @@ -0,0 +1,42 @@ +// ============================================================================== +// MasterHttpRelayVPN +// Author: MasterkinG32 +// Github: https://github.com/masterking32 +// Year: 2026 +// ============================================================================== + +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + + "masterhttprelayvpn/internal/config" + lg "masterhttprelayvpn/internal/logger" + "masterhttprelayvpn/internal/server" +) + +func main() { + logger := lg.New("MasterHttpRelayVPN Server", "INFO") + + cfg, err := config.Load("server.toml") + if err != nil { + logger.Fatalf("load config: %v", err) + } + if err := cfg.ValidateServer(); err != nil { + logger.Fatalf("validate server config: %v", err) + } + + logger = lg.New("MasterHttpRelayVPN Server", cfg.LogLevel) + + app := server.New(cfg, logger) + + ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer cancel() + + if err := app.Run(ctx); err != nil { + logger.Fatalf("run server: %v", err) + } +} diff --git a/internal/client/sender_workers.go b/internal/client/sender_workers.go index d6f478f..5d0ac87 100644 --- a/internal/client/sender_workers.go +++ b/internal/client/sender_workers.go @@ -9,11 +9,6 @@ package client import ( "bytes" "context" - "crypto/aes" - "crypto/cipher" - "crypto/rand" - "crypto/sha256" - "encoding/json" "fmt" "io" "net/http" @@ -73,7 +68,7 @@ func (w *sendWorker) run(ctx context.Context, c *Client) { continue } - body, err := encryptBatch(batch, c.cfg.AESEncryptionKey) + body, err := protocol.EncryptBatch(batch, c.cfg.AESEncryptionKey) if err != nil { c.log.Errorf("worker=%d encrypt batch failed: %v", w.id, err) c.requeueSelected(selected) @@ -176,29 +171,3 @@ func (w *sendWorker) postBatch(ctx context.Context, c *Client, batch protocol.Ba _, _ = io.Copy(io.Discard, resp.Body) return nil } - -func encryptBatch(batch protocol.Batch, keyText string) ([]byte, error) { - plain, err := json.Marshal(batch) - if err != nil { - return nil, err - } - - key := sha256.Sum256([]byte(keyText)) - block, err := aes.NewCipher(key[:]) - if err != nil { - return nil, err - } - - gcm, err := cipher.NewGCM(block) - if err != nil { - return nil, err - } - - nonce := make([]byte, gcm.NonceSize()) - if _, err := rand.Read(nonce); err != nil { - return nil, err - } - - ciphertext := gcm.Seal(nil, nonce, plain, nil) - return append(nonce, ciphertext...), nil -} diff --git a/internal/config/config.go b/internal/config/config.go index e5db491..02bbd3b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -18,6 +18,8 @@ import ( type Config struct { AESEncryptionKey string RelayURL string + ServerHost string + ServerPort int SOCKSHost string SOCKSPort int SOCKSAuth bool @@ -31,12 +33,17 @@ type Config struct { HTTPRequestTimeoutMS int WorkerPollIntervalMS int MaxQueueBytesPerSOCKS int + SessionIdleTimeoutMS int + SOCKSIdleTimeoutMS int + ReadBodyLimitBytes int } func Load(path string) (Config, error) { cfg := Config{ SOCKSHost: "127.0.0.1", SOCKSPort: 1080, + ServerHost: "127.0.0.1", + ServerPort: 28080, LogLevel: "INFO", MaxChunkSize: 16 * 1024, MaxPacketsPerBatch: 32, @@ -45,6 +52,9 @@ func Load(path string) (Config, error) { HTTPRequestTimeoutMS: 15000, WorkerPollIntervalMS: 200, MaxQueueBytesPerSOCKS: 1024 * 1024, + SessionIdleTimeoutMS: 5 * 60 * 1000, + SOCKSIdleTimeoutMS: 2 * 60 * 1000, + ReadBodyLimitBytes: 2 * 1024 * 1024, } file, err := os.Open(path) @@ -73,6 +83,14 @@ func Load(path string) (Config, error) { cfg.AESEncryptionKey = trimString(value) case "RELAY_URL": cfg.RelayURL = trimString(value) + case "SERVER_HOST": + cfg.ServerHost = trimString(value) + case "SERVER_PORT": + port, err := strconv.Atoi(value) + if err != nil { + return Config{}, fmt.Errorf("parse SERVER_PORT: %w", err) + } + cfg.ServerPort = port case "SOCKS_HOST": cfg.SOCKSHost = trimString(value) case "SOCKS_PORT": @@ -135,6 +153,24 @@ func Load(path string) (Config, error) { return Config{}, fmt.Errorf("parse MAX_QUEUE_BYTES_PER_SOCKS: %w", err) } cfg.MaxQueueBytesPerSOCKS = size + case "SESSION_IDLE_TIMEOUT_MS": + timeout, err := strconv.Atoi(value) + if err != nil { + return Config{}, fmt.Errorf("parse SESSION_IDLE_TIMEOUT_MS: %w", err) + } + cfg.SessionIdleTimeoutMS = timeout + case "SOCKS_IDLE_TIMEOUT_MS": + timeout, err := strconv.Atoi(value) + if err != nil { + return Config{}, fmt.Errorf("parse SOCKS_IDLE_TIMEOUT_MS: %w", err) + } + cfg.SOCKSIdleTimeoutMS = timeout + case "READ_BODY_LIMIT_BYTES": + size, err := strconv.Atoi(value) + if err != nil { + return Config{}, fmt.Errorf("parse READ_BODY_LIMIT_BYTES: %w", err) + } + cfg.ReadBodyLimitBytes = size } } @@ -142,49 +178,72 @@ func Load(path string) (Config, error) { return Config{}, err } - if cfg.SOCKSAuth && (cfg.SOCKSUsername == "" || cfg.SOCKSPassword == "") { - return Config{}, fmt.Errorf("SOCKS auth enabled but username/password missing") - } - - if cfg.SOCKSPort < 1 || cfg.SOCKSPort > 65535 { - return Config{}, fmt.Errorf("invalid SOCKS_PORT: %d", cfg.SOCKSPort) - } - if strings.TrimSpace(cfg.RelayURL) == "" { - return Config{}, fmt.Errorf("RELAY_URL is required") - } - if strings.TrimSpace(cfg.AESEncryptionKey) == "" { - return Config{}, fmt.Errorf("AES_ENCRYPTION_KEY is required") - } - - if cfg.MaxChunkSize < 1 { - return Config{}, fmt.Errorf("invalid MAX_CHUNK_SIZE: %d", cfg.MaxChunkSize) - } - - if cfg.MaxPacketsPerBatch < 1 { - return Config{}, fmt.Errorf("invalid MAX_PACKETS_PER_BATCH: %d", cfg.MaxPacketsPerBatch) - } - - if cfg.MaxBatchBytes < cfg.MaxChunkSize { - return Config{}, fmt.Errorf("MAX_BATCH_BYTES must be >= MAX_CHUNK_SIZE") - } - - if cfg.WorkerCount < 1 { - return Config{}, fmt.Errorf("invalid WORKER_COUNT: %d", cfg.WorkerCount) - } - if cfg.HTTPRequestTimeoutMS < 1 { - return Config{}, fmt.Errorf("invalid HTTP_REQUEST_TIMEOUT_MS: %d", cfg.HTTPRequestTimeoutMS) - } - if cfg.WorkerPollIntervalMS < 1 { - return Config{}, fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", cfg.WorkerPollIntervalMS) - } - - if cfg.MaxQueueBytesPerSOCKS < cfg.MaxChunkSize { - return Config{}, fmt.Errorf("MAX_QUEUE_BYTES_PER_SOCKS must be >= MAX_CHUNK_SIZE") - } - return cfg, nil } +func (c Config) ValidateClient() error { + if err := c.validateShared(); err != nil { + return err + } + if c.SOCKSAuth && (c.SOCKSUsername == "" || c.SOCKSPassword == "") { + return fmt.Errorf("SOCKS auth enabled but username/password missing") + } + if c.SOCKSPort < 1 || c.SOCKSPort > 65535 { + return fmt.Errorf("invalid SOCKS_PORT: %d", c.SOCKSPort) + } + if strings.TrimSpace(c.RelayURL) == "" { + return fmt.Errorf("RELAY_URL is required") + } + if c.HTTPRequestTimeoutMS < 1 { + return fmt.Errorf("invalid HTTP_REQUEST_TIMEOUT_MS: %d", c.HTTPRequestTimeoutMS) + } + if c.WorkerPollIntervalMS < 1 { + return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS) + } + if c.MaxQueueBytesPerSOCKS < c.MaxChunkSize { + return fmt.Errorf("MAX_QUEUE_BYTES_PER_SOCKS must be >= MAX_CHUNK_SIZE") + } + return nil +} + +func (c Config) ValidateServer() error { + if err := c.validateShared(); err != nil { + return err + } + if c.ServerPort < 1 || c.ServerPort > 65535 { + return fmt.Errorf("invalid SERVER_PORT: %d", c.ServerPort) + } + if c.SessionIdleTimeoutMS < 1 { + return fmt.Errorf("invalid SESSION_IDLE_TIMEOUT_MS: %d", c.SessionIdleTimeoutMS) + } + if c.SOCKSIdleTimeoutMS < 1 { + return fmt.Errorf("invalid SOCKS_IDLE_TIMEOUT_MS: %d", c.SOCKSIdleTimeoutMS) + } + if c.ReadBodyLimitBytes < c.MaxChunkSize { + return fmt.Errorf("READ_BODY_LIMIT_BYTES must be >= MAX_CHUNK_SIZE") + } + return nil +} + +func (c Config) validateShared() error { + if strings.TrimSpace(c.AESEncryptionKey) == "" { + return fmt.Errorf("AES_ENCRYPTION_KEY is required") + } + if c.MaxChunkSize < 1 { + return fmt.Errorf("invalid MAX_CHUNK_SIZE: %d", c.MaxChunkSize) + } + if c.MaxPacketsPerBatch < 1 { + return fmt.Errorf("invalid MAX_PACKETS_PER_BATCH: %d", c.MaxPacketsPerBatch) + } + if c.MaxBatchBytes < c.MaxChunkSize { + return fmt.Errorf("MAX_BATCH_BYTES must be >= MAX_CHUNK_SIZE") + } + if c.WorkerCount < 1 { + return fmt.Errorf("invalid WORKER_COUNT: %d", c.WorkerCount) + } + return nil +} + func trimString(value string) string { return strings.Trim(value, `"`) } diff --git a/internal/protocol/packet.go b/internal/protocol/packet.go index e23d90e..f87121c 100644 --- a/internal/protocol/packet.go +++ b/internal/protocol/packet.go @@ -7,8 +7,12 @@ package protocol import ( + "crypto/aes" + "crypto/cipher" "crypto/rand" + "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "fmt" "net" @@ -216,3 +220,61 @@ func (b Batch) Validate() error { return nil } + +func EncryptBatch(batch Batch, keyText string) ([]byte, error) { + plain, err := json.Marshal(batch) + if err != nil { + return nil, err + } + + key := sha256.Sum256([]byte(keyText)) + block, err := aes.NewCipher(key[:]) + if err != nil { + return nil, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return nil, err + } + + nonce := make([]byte, gcm.NonceSize()) + if _, err := rand.Read(nonce); err != nil { + return nil, err + } + + ciphertext := gcm.Seal(nil, nonce, plain, nil) + return append(nonce, ciphertext...), nil +} + +func DecryptBatch(ciphertext []byte, keyText string) (Batch, error) { + key := sha256.Sum256([]byte(keyText)) + block, err := aes.NewCipher(key[:]) + if err != nil { + return Batch{}, err + } + + gcm, err := cipher.NewGCM(block) + if err != nil { + return Batch{}, err + } + if len(ciphertext) < gcm.NonceSize() { + return Batch{}, fmt.Errorf("encrypted body is shorter than nonce size") + } + + nonce := ciphertext[:gcm.NonceSize()] + encrypted := ciphertext[gcm.NonceSize():] + plain, err := gcm.Open(nil, nonce, encrypted, nil) + if err != nil { + return Batch{}, err + } + + var batch Batch + if err := json.Unmarshal(plain, &batch); err != nil { + return Batch{}, err + } + if err := batch.Validate(); err != nil { + return Batch{}, err + } + return batch, nil +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..0fde6af --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,371 @@ +// ============================================================================== +// MasterHttpRelayVPN +// Author: MasterkinG32 +// Github: https://github.com/masterking32 +// Year: 2026 +// ============================================================================== +package server + +import ( + "context" + "fmt" + "io" + "net" + "net/http" + "sync" + "time" + + "masterhttprelayvpn/internal/config" + "masterhttprelayvpn/internal/logger" + "masterhttprelayvpn/internal/protocol" +) + +type Server struct { + cfg config.Config + log *logger.Logger + + mu sync.RWMutex + sessions map[string]*ClientSession +} + +type ClientSession struct { + ClientSessionKey string + CreatedAt time.Time + LastActivityAt time.Time + SOCKSConnections map[uint64]*SOCKSState +} + +type SOCKSState struct { + ID uint64 + CreatedAt time.Time + LastActivityAt time.Time + Target *protocol.Target + ConnectSeen bool + ConnectAcked bool + CloseReadSeen bool + CloseWriteSeen bool + ResetSeen bool + ReceivedBytes uint64 + LastSequenceSeen uint64 +} + +func New(cfg config.Config, lg *logger.Logger) *Server { + return &Server{ + cfg: cfg, + log: lg, + sessions: make(map[string]*ClientSession), + } +} + +func (s *Server) Run(ctx context.Context) error { + addr := fmt.Sprintf("%s:%d", s.cfg.ServerHost, s.cfg.ServerPort) + + mux := http.NewServeMux() + mux.HandleFunc("/", s.handleRelay) + mux.HandleFunc("/relay", s.handleRelay) + + httpServer := &http.Server{ + Addr: addr, + Handler: mux, + ReadHeaderTimeout: 10 * time.Second, + IdleTimeout: 60 * time.Second, + } + + go s.cleanupLoop(ctx) + go func() { + <-ctx.Done() + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = httpServer.Shutdown(shutdownCtx) + }() + + s.log.Infof("server listening on %s", addr) + if err := httpServer.ListenAndServe(); err != nil && err != http.ErrServerClosed { + return err + } + return nil +} + +func (s *Server) handleRelay(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + r.Body = http.MaxBytesReader(w, r.Body, int64(s.cfg.ReadBodyLimitBytes)) + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, "read body failed", http.StatusBadRequest) + return + } + + batch, err := protocol.DecryptBatch(body, s.cfg.AESEncryptionKey) + if err != nil { + s.log.Warnf("decrypt batch failed: %v", err) + http.Error(w, "invalid encrypted payload", http.StatusBadRequest) + return + } + + responseBatch, err := s.processBatch(batch) + if err != nil { + s.log.Warnf("process batch=%s failed: %v", batch.BatchID, err) + http.Error(w, "batch processing failed", http.StatusBadRequest) + return + } + + if len(responseBatch.Packets) == 0 { + w.WriteHeader(http.StatusNoContent) + return + } + + encrypted, err := protocol.EncryptBatch(responseBatch, s.cfg.AESEncryptionKey) + if err != nil { + s.log.Errorf("encrypt response batch failed: %v", err) + http.Error(w, "response encryption failed", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("X-Relay-Version", fmt.Sprintf("%d", protocol.CurrentVersion)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(encrypted) +} + +func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) { + session := s.getOrCreateSession(batch.ClientSessionKey) + now := time.Now() + + s.mu.Lock() + session.LastActivityAt = now + + responses := make([]protocol.Packet, 0, len(batch.Packets)) + for _, packet := range batch.Packets { + response, err := s.processPacketLocked(session, packet, now) + if err != nil { + s.mu.Unlock() + return protocol.Batch{}, err + } + if response != nil { + responses = append(responses, *response) + } + } + s.mu.Unlock() + + if len(responses) == 0 { + return protocol.Batch{}, nil + } + return protocol.NewBatch(batch.ClientSessionKey, protocol.NewBatchID(), responses), nil +} + +func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Packet, now time.Time) (*protocol.Packet, error) { + if packet.ClientSessionKey != session.ClientSessionKey { + return nil, fmt.Errorf("packet client session key mismatch") + } + + switch packet.Type { + case protocol.PacketTypeSOCKSConnect: + if packet.Target == nil { + return nil, fmt.Errorf("socks_connect missing target") + } + + socksState, exists := session.SOCKSConnections[packet.SOCKSID] + if !exists { + socksState = &SOCKSState{ + ID: packet.SOCKSID, + CreatedAt: now, + LastActivityAt: now, + Target: packet.Target, + ConnectSeen: true, + ConnectAcked: true, + LastSequenceSeen: packet.Sequence, + } + session.SOCKSConnections[packet.SOCKSID] = socksState + } else { + socksState.LastActivityAt = now + socksState.Target = packet.Target + socksState.ConnectSeen = true + socksState.ConnectAcked = true + if packet.Sequence > socksState.LastSequenceSeen { + socksState.LastSequenceSeen = packet.Sequence + } + } + + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSConnectAck) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + return &response, nil + + case protocol.PacketTypeSOCKSData: + socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) + socksState.LastActivityAt = now + socksState.ReceivedBytes += uint64(len(packet.Payload)) + if packet.Sequence > socksState.LastSequenceSeen { + socksState.LastSequenceSeen = packet.Sequence + } + + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSDataAck) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + response.FragmentID = packet.FragmentID + response.TotalFragments = packet.TotalFragments + response.Final = packet.Final + return &response, nil + + case protocol.PacketTypeSOCKSCloseRead: + socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) + socksState.LastActivityAt = now + socksState.CloseReadSeen = true + if packet.Sequence > socksState.LastSequenceSeen { + socksState.LastSequenceSeen = packet.Sequence + } + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseRead) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + return &response, nil + + case protocol.PacketTypeSOCKSCloseWrite: + socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) + socksState.LastActivityAt = now + socksState.CloseWriteSeen = true + if packet.Sequence > socksState.LastSequenceSeen { + socksState.LastSequenceSeen = packet.Sequence + } + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseWrite) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + return &response, nil + + case protocol.PacketTypeSOCKSRST: + socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) + socksState.LastActivityAt = now + socksState.ResetSeen = true + if packet.Sequence > socksState.LastSequenceSeen { + socksState.LastSequenceSeen = packet.Sequence + } + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + delete(session.SOCKSConnections, packet.SOCKSID) + return &response, nil + + case protocol.PacketTypePing: + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypePong) + response.Payload = append([]byte(nil), packet.Payload...) + return &response, nil + + case protocol.PacketTypeSOCKSConnectAck, + protocol.PacketTypeSOCKSConnectFail, + protocol.PacketTypeSOCKSRuleSetDenied, + protocol.PacketTypeSOCKSNetworkUnreachable, + protocol.PacketTypeSOCKSHostUnreachable, + protocol.PacketTypeSOCKSConnectionRefused, + protocol.PacketTypeSOCKSTTLExpired, + protocol.PacketTypeSOCKSCommandUnsupported, + protocol.PacketTypeSOCKSAddressTypeUnsupported, + protocol.PacketTypeSOCKSAuthFailed, + protocol.PacketTypeSOCKSUpstreamUnavailable, + protocol.PacketTypeSOCKSDataAck, + protocol.PacketTypePong: + return nil, nil + + default: + return nil, fmt.Errorf("unsupported packet type: %s", packet.Type) + } +} + +func (s *Server) getOrCreateSession(clientSessionKey string) *ClientSession { + s.mu.RLock() + existing := s.sessions[clientSessionKey] + s.mu.RUnlock() + if existing != nil { + return existing + } + + s.mu.Lock() + defer s.mu.Unlock() + + existing = s.sessions[clientSessionKey] + if existing != nil { + return existing + } + + now := time.Now() + session := &ClientSession{ + ClientSessionKey: clientSessionKey, + CreatedAt: now, + LastActivityAt: now, + SOCKSConnections: make(map[uint64]*SOCKSState), + } + s.sessions[clientSessionKey] = session + s.log.Infof("created client session %s", clientSessionKey) + return session +} + +func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet protocol.Packet, now time.Time) *SOCKSState { + socksState := session.SOCKSConnections[packet.SOCKSID] + if socksState != nil { + return socksState + } + + socksState = &SOCKSState{ + ID: packet.SOCKSID, + CreatedAt: now, + LastActivityAt: now, + Target: packet.Target, + LastSequenceSeen: packet.Sequence, + } + session.SOCKSConnections[packet.SOCKSID] = socksState + return socksState +} + +func (s *Server) cleanupLoop(ctx context.Context) { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + s.cleanupExpired() + } + } +} + +func (s *Server) cleanupExpired() { + sessionTTL := time.Duration(s.cfg.SessionIdleTimeoutMS) * time.Millisecond + socksTTL := time.Duration(s.cfg.SOCKSIdleTimeoutMS) * time.Millisecond + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + for clientSessionKey, session := range s.sessions { + for socksID, socksState := range session.SOCKSConnections { + if now.Sub(socksState.LastActivityAt) > socksTTL { + delete(session.SOCKSConnections, socksID) + s.log.Debugf("expired socks state session=%s socks_id=%d", clientSessionKey, socksID) + } + } + + if len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL { + delete(s.sessions, clientSessionKey) + s.log.Infof("expired client session %s", clientSessionKey) + } + } +} + +func (s *Server) SessionSnapshot() (sessions int, socksConnections int) { + s.mu.RLock() + defer s.mu.RUnlock() + + sessions = len(s.sessions) + for _, session := range s.sessions { + socksConnections += len(session.SOCKSConnections) + } + return sessions, socksConnections +} + +func LocalListenAddress(host string, port int) string { + return net.JoinHostPort(host, fmt.Sprintf("%d", port)) +} diff --git a/server.toml b/server.toml new file mode 100644 index 0000000..647443e --- /dev/null +++ b/server.toml @@ -0,0 +1,15 @@ +# ============================================================================== +AES_ENCRYPTION_KEY = "c4710a45afed2fdc00e0522c70802e71" +SERVER_HOST = "127.0.0.1" +SERVER_PORT = 28080 +# ============================================================================== +LOG_LEVEL = "INFO" +# ============================================================================== +MAX_CHUNK_SIZE = 16384 +MAX_PACKETS_PER_BATCH = 32 +MAX_BATCH_BYTES = 262144 +WORKER_COUNT = 4 +SESSION_IDLE_TIMEOUT_MS = 300000 +SOCKS_IDLE_TIMEOUT_MS = 120000 +READ_BODY_LIMIT_BYTES = 2097152 +# ==============================================================================