Make server outbound sequencing and upstream connection lifecycle concurrency-safe

This commit is contained in:
Amin.MasterkinG
2026-04-21 12:29:19 +03:30
parent 720ab14a44
commit 22f13fb234
2 changed files with 132 additions and 28 deletions
+69 -28
View File
@@ -15,6 +15,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"masterhttprelayvpn/internal/config" "masterhttprelayvpn/internal/config"
@@ -56,8 +57,9 @@ type SOCKSState struct {
OutboundSequence uint64 OutboundSequence uint64
PendingInbound map[uint64][]PendingInboundPacket PendingInbound map[uint64][]PendingInboundPacket
UpstreamConn net.Conn UpstreamConn net.Conn
activityMu sync.RWMutex
upstreamStateMu sync.RWMutex
upstreamWriteMu sync.Mutex upstreamWriteMu sync.Mutex
upstreamCloseMu sync.Mutex
upstreamReadEOF bool upstreamReadEOF bool
upstreamWriteEOF bool upstreamWriteEOF bool
queueMu sync.Mutex queueMu sync.Mutex
@@ -237,7 +239,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
} }
session.SOCKSConnections[packet.SOCKSID] = socksState session.SOCKSConnections[packet.SOCKSID] = socksState
} else { } else {
socksState.LastActivityAt = now socksState.setLastActivityAt(now)
socksState.Target = packet.Target socksState.Target = packet.Target
socksState.ConnectSeen = true socksState.ConnectSeen = true
if socksState.PendingInbound == nil { if socksState.PendingInbound == nil {
@@ -248,7 +250,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
} }
} }
if socksState.UpstreamConn == nil { if _, connected := socksState.currentUpstreamConn(); !connected {
s.log.Debugf( s.log.Debugf(
"<gray>dial upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>", "<gray>dial upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey, packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey,
@@ -267,7 +269,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
return []protocol.Packet{response}, nil return []protocol.Packet{response}, nil
} }
socksState.UpstreamConn = upstreamConn socksState.setUpstreamConn(upstreamConn)
socksState.ConnectAcked = true socksState.ConnectAcked = true
s.log.Infof( s.log.Infof(
"<green>upstream connected socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></green>", "<green>upstream connected socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></green>",
@@ -408,7 +410,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
} }
func (s *Server) applyOrderedPacketLocked(session *ClientSession, socksState *SOCKSState, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) { func (s *Server) applyOrderedPacketLocked(session *ClientSession, socksState *SOCKSState, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) {
socksState.LastActivityAt = now socksState.setLastActivityAt(now)
if packet.Sequence > socksState.LastSequenceSeen { if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence socksState.LastSequenceSeen = packet.Sequence
} }
@@ -559,13 +561,14 @@ func (s *Server) drainSessionOutboundLocked(session *ClientSession) []protocol.P
} }
func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSState) { func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSState) {
if socksState.UpstreamConn == nil {
return
}
buffer := make([]byte, s.cfg.MaxChunkSize) buffer := make([]byte, s.cfg.MaxChunkSize)
for { for {
n, err := socksState.UpstreamConn.Read(buffer) upstreamConn, ok := socksState.currentUpstreamConn()
if !ok {
return
}
n, err := upstreamConn.Read(buffer)
if n > 0 { if n > 0 {
chunk := append([]byte(nil), buffer[:n]...) chunk := append([]byte(nil), buffer[:n]...)
if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) { if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) {
@@ -577,7 +580,7 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
_ = socksState.closeUpstream() _ = socksState.closeUpstream()
return return
} }
socksState.LastActivityAt = time.Now() socksState.setLastActivityAt(time.Now())
queueDepth, queueBytes := socksState.queueSnapshot() queueDepth, queueBytes := socksState.queueSnapshot()
s.log.Debugf( s.log.Debugf(
"<gray>read upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> bytes=<cyan>%d</cyan> queue_depth=<cyan>%d</cyan> queue_bytes=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>", "<gray>read upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> bytes=<cyan>%d</cyan> queue_depth=<cyan>%d</cyan> queue_bytes=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
@@ -614,8 +617,7 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
} }
func (s *SOCKSState) nextOutboundSequence() uint64 { func (s *SOCKSState) nextOutboundSequence() uint64 {
s.OutboundSequence++ return atomic.AddUint64(&s.OutboundSequence, 1)
return s.OutboundSequence
} }
func (s *SOCKSState) expectedInboundSequenceLocked() uint64 { func (s *SOCKSState) expectedInboundSequenceLocked() uint64 {
@@ -821,58 +823,97 @@ func (s *SOCKSState) writeUpstream(payload []byte) error {
} }
s.upstreamWriteMu.Lock() s.upstreamWriteMu.Lock()
defer s.upstreamWriteMu.Unlock() defer s.upstreamWriteMu.Unlock()
if s.UpstreamConn == nil { upstreamConn, ok := s.currentUpstreamConn()
if !ok {
return fmt.Errorf("upstream connection is not established") return fmt.Errorf("upstream connection is not established")
} }
_, err := s.UpstreamConn.Write(payload) _, err := upstreamConn.Write(payload)
return err return err
} }
func (s *SOCKSState) closeUpstream() error { func (s *SOCKSState) closeUpstream() error {
s.upstreamCloseMu.Lock() s.upstreamStateMu.Lock()
defer s.upstreamCloseMu.Unlock() upstreamConn := s.UpstreamConn
if s.UpstreamConn == nil { if upstreamConn == nil {
s.upstreamStateMu.Unlock()
return nil return nil
} }
err := s.UpstreamConn.Close()
s.UpstreamConn = nil s.UpstreamConn = nil
s.upstreamReadEOF = true s.upstreamReadEOF = true
s.upstreamWriteEOF = true s.upstreamWriteEOF = true
s.upstreamStateMu.Unlock()
err := upstreamConn.Close()
return err return err
} }
func (s *SOCKSState) closeUpstreamRead() error { func (s *SOCKSState) closeUpstreamRead() error {
s.upstreamCloseMu.Lock() s.upstreamStateMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.upstreamReadEOF { if s.upstreamReadEOF {
s.upstreamStateMu.Unlock()
return nil return nil
} }
s.upstreamReadEOF = true s.upstreamReadEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok { upstreamConn := s.UpstreamConn
s.upstreamStateMu.Unlock()
if tcpConn, ok := upstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseRead() return tcpConn.CloseRead()
} }
return nil return nil
} }
func (s *SOCKSState) closeUpstreamWrite() error { func (s *SOCKSState) closeUpstreamWrite() error {
s.upstreamCloseMu.Lock() s.upstreamStateMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.upstreamWriteEOF { if s.upstreamWriteEOF {
s.upstreamStateMu.Unlock()
return nil return nil
} }
s.upstreamWriteEOF = true s.upstreamWriteEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok { upstreamConn := s.UpstreamConn
s.upstreamStateMu.Unlock()
if tcpConn, ok := upstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseWrite() return tcpConn.CloseWrite()
} }
if s.UpstreamConn == nil { if upstreamConn == nil {
return nil return nil
} }
err := s.UpstreamConn.Close()
err := upstreamConn.Close()
s.upstreamStateMu.Lock()
s.UpstreamConn = nil s.UpstreamConn = nil
s.upstreamReadEOF = true s.upstreamReadEOF = true
s.upstreamStateMu.Unlock()
return err return err
} }
func (s *SOCKSState) setUpstreamConn(conn net.Conn) {
s.upstreamStateMu.Lock()
defer s.upstreamStateMu.Unlock()
s.UpstreamConn = conn
s.upstreamReadEOF = false
s.upstreamWriteEOF = false
}
func (s *SOCKSState) currentUpstreamConn() (net.Conn, bool) {
s.upstreamStateMu.RLock()
defer s.upstreamStateMu.RUnlock()
if s.UpstreamConn == nil {
return nil, false
}
return s.UpstreamConn, true
}
func (s *SOCKSState) setLastActivityAt(now time.Time) {
s.activityMu.Lock()
s.LastActivityAt = now
s.activityMu.Unlock()
}
func (s *SOCKSState) lastActivityAt() time.Time {
s.activityMu.RLock()
defer s.activityMu.RUnlock()
return s.LastActivityAt
}
func (s *Server) cleanupLoop(ctx context.Context) { func (s *Server) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(30 * time.Second) ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop() defer ticker.Stop()
@@ -903,7 +944,7 @@ func (s *Server) cleanupExpired() {
for clientSessionKey, session := range sessionSnapshots { for clientSessionKey, session := range sessionSnapshots {
session.mu.Lock() session.mu.Lock()
for socksID, socksState := range session.SOCKSConnections { for socksID, socksState := range session.SOCKSConnections {
if now.Sub(socksState.LastActivityAt) > socksTTL { if now.Sub(socksState.lastActivityAt()) > socksTTL {
targetAddress := targetAddressForLog(socksState.Target) targetAddress := targetAddressForLog(socksState.Target)
_ = socksState.closeUpstream() _ = socksState.closeUpstream()
socksState.release() socksState.release()
+63
View File
@@ -3,6 +3,7 @@ package server
import ( import (
"errors" "errors"
"net" "net"
"sync"
"testing" "testing"
"time" "time"
@@ -256,6 +257,68 @@ func TestProcessBatchBlockedSessionDoesNotBlockOtherSessions(t *testing.T) {
} }
} }
func TestSOCKSStateNextOutboundSequenceIsConcurrentSafe(t *testing.T) {
socksState := &SOCKSState{}
const workers = 32
const iterationsPerWorker = 64
start := make(chan struct{})
var wg sync.WaitGroup
results := make(chan uint64, workers*iterationsPerWorker)
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
for j := 0; j < iterationsPerWorker; j++ {
results <- socksState.nextOutboundSequence()
}
}()
}
close(start)
wg.Wait()
close(results)
seen := make(map[uint64]struct{}, workers*iterationsPerWorker)
var maxSeq uint64
for seq := range results {
if _, exists := seen[seq]; exists {
t.Fatalf("duplicate outbound sequence generated: %d", seq)
}
seen[seq] = struct{}{}
if seq > maxSeq {
maxSeq = seq
}
}
expected := workers * iterationsPerWorker
if len(seen) != expected {
t.Fatalf("expected %d unique sequences, got %d", expected, len(seen))
}
if maxSeq != uint64(expected) {
t.Fatalf("expected max sequence %d, got %d", expected, maxSeq)
}
}
func TestSOCKSStateCloseUpstreamClearsConnectionSnapshot(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
socksState := &SOCKSState{}
socksState.setUpstreamConn(serverConn)
if err := socksState.closeUpstream(); err != nil {
t.Fatalf("expected closeUpstream to succeed, got %v", err)
}
if _, ok := socksState.currentUpstreamConn(); ok {
t.Fatal("expected upstream connection snapshot to be cleared after close")
}
}
func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet { func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData) packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = socksID packet.SOCKSID = socksID