diff --git a/internal/server/server.go b/internal/server/server.go
index 0af0fce..d41f437 100644
--- a/internal/server/server.go
+++ b/internal/server/server.go
@@ -26,11 +26,13 @@ type Server struct {
cfg config.Config
log *logger.Logger
- mu sync.RWMutex
- sessions map[string]*ClientSession
+ mu sync.RWMutex
+ sessions map[string]*ClientSession
+ dialUpstream func(network string, address string, timeout time.Duration) (net.Conn, error)
}
type ClientSession struct {
+ mu sync.Mutex
ClientSessionKey string
CreatedAt time.Time
LastActivityAt time.Time
@@ -71,9 +73,10 @@ type PendingInboundPacket struct {
func New(cfg config.Config, lg *logger.Logger) *Server {
return &Server{
- cfg: cfg,
- log: lg,
- sessions: make(map[string]*ClientSession),
+ cfg: cfg,
+ log: lg,
+ sessions: make(map[string]*ClientSession),
+ dialUpstream: net.DialTimeout,
}
}
@@ -171,7 +174,9 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
session := s.getOrCreateSession(batch.ClientSessionKey)
now := time.Now()
- s.mu.Lock()
+ session.mu.Lock()
+ defer session.mu.Unlock()
+
session.LastActivityAt = now
responses := make([]protocol.Packet, 0, len(batch.Packets))
@@ -183,7 +188,6 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
)
orderedResponses, err := s.processPacketLocked(session, packet, now)
if err != nil {
- s.mu.Unlock()
return protocol.Batch{}, err
}
for _, response := range orderedResponses {
@@ -201,7 +205,6 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
)
responses = append(responses, outbound)
}
- s.mu.Unlock()
if len(responses) == 0 {
return protocol.Batch{}, nil
@@ -250,7 +253,8 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
"dial upstream socks_id=%d target=%s client_session_key=%s",
packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey,
)
- upstreamConn, err := net.DialTimeout("tcp", packet.Target.Address(), 10*time.Second)
+ upstreamConn, err := s.dial("tcp", packet.Target.Address(), 10*time.Second)
+
if err != nil {
s.log.Warnf(
"upstream dial failed socks_id=%d target=%s client_session_key=%s error=%v",
@@ -262,6 +266,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
response.Payload = []byte(err.Error())
return []protocol.Packet{response}, nil
}
+
socksState.UpstreamConn = upstreamConn
socksState.ConnectAcked = true
s.log.Infof(
@@ -887,10 +892,16 @@ func (s *Server) cleanupExpired() {
socksTTL := time.Duration(s.cfg.SOCKSIdleTimeoutMS) * time.Millisecond
now := time.Now()
- s.mu.Lock()
- defer s.mu.Unlock()
-
+ s.mu.RLock()
+ sessionSnapshots := make(map[string]*ClientSession, len(s.sessions))
for clientSessionKey, session := range s.sessions {
+ sessionSnapshots[clientSessionKey] = session
+ }
+ s.mu.RUnlock()
+
+ sessionsToDelete := make([]string, 0)
+ for clientSessionKey, session := range sessionSnapshots {
+ session.mu.Lock()
for socksID, socksState := range session.SOCKSConnections {
if now.Sub(socksState.LastActivityAt) > socksTTL {
targetAddress := targetAddressForLog(socksState.Target)
@@ -902,23 +913,58 @@ func (s *Server) cleanupExpired() {
}
if len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL {
- delete(s.sessions, clientSessionKey)
- s.log.Infof("expired client session %s", clientSessionKey)
+ sessionsToDelete = append(sessionsToDelete, clientSessionKey)
}
+ session.mu.Unlock()
+ }
+
+ if len(sessionsToDelete) == 0 {
+ return
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for _, clientSessionKey := range sessionsToDelete {
+ session := s.sessions[clientSessionKey]
+ if session == nil {
+ continue
+ }
+ session.mu.Lock()
+ expired := len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL
+ session.mu.Unlock()
+ if !expired {
+ continue
+ }
+ delete(s.sessions, clientSessionKey)
+ s.log.Infof("expired client session %s", clientSessionKey)
}
}
func (s *Server) SessionSnapshot() (sessions int, socksConnections int) {
s.mu.RLock()
- defer s.mu.RUnlock()
-
- sessions = len(s.sessions)
+ sessionList := make([]*ClientSession, 0, len(s.sessions))
for _, session := range s.sessions {
- socksConnections += len(session.SOCKSConnections)
+ sessionList = append(sessionList, session)
}
+ sessions = len(sessionList)
+ s.mu.RUnlock()
+
+ for _, session := range sessionList {
+ session.mu.Lock()
+ socksConnections += len(session.SOCKSConnections)
+ session.mu.Unlock()
+ }
+
return sessions, socksConnections
}
+func (s *Server) dial(network string, address string, timeout time.Duration) (net.Conn, error) {
+ if s.dialUpstream == nil {
+ return net.DialTimeout(network, address, timeout)
+ }
+ return s.dialUpstream(network, address, timeout)
+}
+
func LocalListenAddress(host string, port int) string {
return net.JoinHostPort(host, fmt.Sprintf("%d", port))
}
diff --git a/internal/server/server_test.go b/internal/server/server_test.go
index b30d60c..8aa767d 100644
--- a/internal/server/server_test.go
+++ b/internal/server/server_test.go
@@ -1,10 +1,13 @@
package server
import (
+ "errors"
+ "net"
"testing"
"time"
"masterhttprelayvpn/internal/config"
+ "masterhttprelayvpn/internal/logger"
"masterhttprelayvpn/internal/protocol"
)
@@ -189,6 +192,70 @@ func TestSOCKSStateReleaseClearsQueueState(t *testing.T) {
}
}
+func TestProcessBatchBlockedSessionDoesNotBlockOtherSessions(t *testing.T) {
+ dialStarted := make(chan struct{})
+ releaseDial := make(chan struct{})
+
+ srv := New(config.Config{
+ MaxPacketsPerBatch: 8,
+ MaxBatchBytes: 1024,
+ MaxReorderBufferPackets: 8,
+ MaxServerQueueBytes: 1024,
+ }, logger.New("server-test", "ERROR"))
+ srv.dialUpstream = func(network string, address string, timeout time.Duration) (net.Conn, error) {
+ if address != "slow.example:80" {
+ return nil, errors.New("unexpected dial target")
+ }
+ close(dialStarted)
+ <-releaseDial
+ return nil, errors.New("forced dial failure")
+ }
+
+ connect := protocol.NewPacket("session-a", protocol.PacketTypeSOCKSConnect)
+ connect.SOCKSID = 1
+ connect.Sequence = 0
+ connect.Target = &protocol.Target{Host: "slow.example", Port: 80}
+
+ errCh := make(chan error, 1)
+ go func() {
+ _, err := srv.processBatch(protocol.NewBatch("session-a", "batch-a", []protocol.Packet{connect}))
+ errCh <- err
+ }()
+
+ select {
+ case <-dialStarted:
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("expected slow session dial to start")
+ }
+
+ ping := protocol.NewPacket("session-b", protocol.PacketTypePing)
+ done := make(chan error, 1)
+ go func() {
+ _, err := srv.processBatch(protocol.NewBatch("session-b", "batch-b", []protocol.Packet{ping}))
+ done <- err
+ }()
+
+ select {
+ case err := <-done:
+ if err != nil {
+ t.Fatalf("expected unrelated session batch to complete, got error: %v", err)
+ }
+ case <-time.After(300 * time.Millisecond):
+ t.Fatal("expected unrelated session batch to complete while first session dial is blocked")
+ }
+
+ close(releaseDial)
+
+ select {
+ case err := <-errCh:
+ if err != nil {
+ t.Fatalf("expected blocked session batch to convert dial failure into response, got error: %v", err)
+ }
+ case <-time.After(500 * time.Millisecond):
+ t.Fatal("expected blocked session batch to finish after releasing dial")
+ }
+}
+
func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = socksID