diff --git a/internal/server/server.go b/internal/server/server.go index 0af0fce..d41f437 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -26,11 +26,13 @@ type Server struct { cfg config.Config log *logger.Logger - mu sync.RWMutex - sessions map[string]*ClientSession + mu sync.RWMutex + sessions map[string]*ClientSession + dialUpstream func(network string, address string, timeout time.Duration) (net.Conn, error) } type ClientSession struct { + mu sync.Mutex ClientSessionKey string CreatedAt time.Time LastActivityAt time.Time @@ -71,9 +73,10 @@ type PendingInboundPacket struct { func New(cfg config.Config, lg *logger.Logger) *Server { return &Server{ - cfg: cfg, - log: lg, - sessions: make(map[string]*ClientSession), + cfg: cfg, + log: lg, + sessions: make(map[string]*ClientSession), + dialUpstream: net.DialTimeout, } } @@ -171,7 +174,9 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) { session := s.getOrCreateSession(batch.ClientSessionKey) now := time.Now() - s.mu.Lock() + session.mu.Lock() + defer session.mu.Unlock() + session.LastActivityAt = now responses := make([]protocol.Packet, 0, len(batch.Packets)) @@ -183,7 +188,6 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) { ) orderedResponses, err := s.processPacketLocked(session, packet, now) if err != nil { - s.mu.Unlock() return protocol.Batch{}, err } for _, response := range orderedResponses { @@ -201,7 +205,6 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) { ) responses = append(responses, outbound) } - s.mu.Unlock() if len(responses) == 0 { return protocol.Batch{}, nil @@ -250,7 +253,8 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac "dial upstream socks_id=%d target=%s client_session_key=%s", packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey, ) - upstreamConn, err := net.DialTimeout("tcp", packet.Target.Address(), 10*time.Second) + upstreamConn, err := s.dial("tcp", packet.Target.Address(), 10*time.Second) + if err != nil { s.log.Warnf( "upstream dial failed socks_id=%d target=%s client_session_key=%s error=%v", @@ -262,6 +266,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac response.Payload = []byte(err.Error()) return []protocol.Packet{response}, nil } + socksState.UpstreamConn = upstreamConn socksState.ConnectAcked = true s.log.Infof( @@ -887,10 +892,16 @@ func (s *Server) cleanupExpired() { socksTTL := time.Duration(s.cfg.SOCKSIdleTimeoutMS) * time.Millisecond now := time.Now() - s.mu.Lock() - defer s.mu.Unlock() - + s.mu.RLock() + sessionSnapshots := make(map[string]*ClientSession, len(s.sessions)) for clientSessionKey, session := range s.sessions { + sessionSnapshots[clientSessionKey] = session + } + s.mu.RUnlock() + + sessionsToDelete := make([]string, 0) + for clientSessionKey, session := range sessionSnapshots { + session.mu.Lock() for socksID, socksState := range session.SOCKSConnections { if now.Sub(socksState.LastActivityAt) > socksTTL { targetAddress := targetAddressForLog(socksState.Target) @@ -902,23 +913,58 @@ func (s *Server) cleanupExpired() { } if len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL { - delete(s.sessions, clientSessionKey) - s.log.Infof("expired client session %s", clientSessionKey) + sessionsToDelete = append(sessionsToDelete, clientSessionKey) } + session.mu.Unlock() + } + + if len(sessionsToDelete) == 0 { + return + } + + s.mu.Lock() + defer s.mu.Unlock() + for _, clientSessionKey := range sessionsToDelete { + session := s.sessions[clientSessionKey] + if session == nil { + continue + } + session.mu.Lock() + expired := len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL + session.mu.Unlock() + if !expired { + continue + } + delete(s.sessions, clientSessionKey) + s.log.Infof("expired client session %s", clientSessionKey) } } func (s *Server) SessionSnapshot() (sessions int, socksConnections int) { s.mu.RLock() - defer s.mu.RUnlock() - - sessions = len(s.sessions) + sessionList := make([]*ClientSession, 0, len(s.sessions)) for _, session := range s.sessions { - socksConnections += len(session.SOCKSConnections) + sessionList = append(sessionList, session) } + sessions = len(sessionList) + s.mu.RUnlock() + + for _, session := range sessionList { + session.mu.Lock() + socksConnections += len(session.SOCKSConnections) + session.mu.Unlock() + } + return sessions, socksConnections } +func (s *Server) dial(network string, address string, timeout time.Duration) (net.Conn, error) { + if s.dialUpstream == nil { + return net.DialTimeout(network, address, timeout) + } + return s.dialUpstream(network, address, timeout) +} + func LocalListenAddress(host string, port int) string { return net.JoinHostPort(host, fmt.Sprintf("%d", port)) } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index b30d60c..8aa767d 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,10 +1,13 @@ package server import ( + "errors" + "net" "testing" "time" "masterhttprelayvpn/internal/config" + "masterhttprelayvpn/internal/logger" "masterhttprelayvpn/internal/protocol" ) @@ -189,6 +192,70 @@ func TestSOCKSStateReleaseClearsQueueState(t *testing.T) { } } +func TestProcessBatchBlockedSessionDoesNotBlockOtherSessions(t *testing.T) { + dialStarted := make(chan struct{}) + releaseDial := make(chan struct{}) + + srv := New(config.Config{ + MaxPacketsPerBatch: 8, + MaxBatchBytes: 1024, + MaxReorderBufferPackets: 8, + MaxServerQueueBytes: 1024, + }, logger.New("server-test", "ERROR")) + srv.dialUpstream = func(network string, address string, timeout time.Duration) (net.Conn, error) { + if address != "slow.example:80" { + return nil, errors.New("unexpected dial target") + } + close(dialStarted) + <-releaseDial + return nil, errors.New("forced dial failure") + } + + connect := protocol.NewPacket("session-a", protocol.PacketTypeSOCKSConnect) + connect.SOCKSID = 1 + connect.Sequence = 0 + connect.Target = &protocol.Target{Host: "slow.example", Port: 80} + + errCh := make(chan error, 1) + go func() { + _, err := srv.processBatch(protocol.NewBatch("session-a", "batch-a", []protocol.Packet{connect})) + errCh <- err + }() + + select { + case <-dialStarted: + case <-time.After(500 * time.Millisecond): + t.Fatal("expected slow session dial to start") + } + + ping := protocol.NewPacket("session-b", protocol.PacketTypePing) + done := make(chan error, 1) + go func() { + _, err := srv.processBatch(protocol.NewBatch("session-b", "batch-b", []protocol.Packet{ping})) + done <- err + }() + + select { + case err := <-done: + if err != nil { + t.Fatalf("expected unrelated session batch to complete, got error: %v", err) + } + case <-time.After(300 * time.Millisecond): + t.Fatal("expected unrelated session batch to complete while first session dial is blocked") + } + + close(releaseDial) + + select { + case err := <-errCh: + if err != nil { + t.Fatalf("expected blocked session batch to convert dial failure into response, got error: %v", err) + } + case <-time.After(500 * time.Millisecond): + t.Fatal("expected blocked session batch to finish after releasing dial") + } +} + func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet { packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData) packet.SOCKSID = socksID