Make server outbound sequencing and upstream connection lifecycle concurrency-safe

This commit is contained in:
Amin.MasterkinG
2026-04-21 12:29:19 +03:30
parent 720ab14a44
commit 22f13fb234
2 changed files with 132 additions and 28 deletions
+69 -28
View File
@@ -15,6 +15,7 @@ import (
"net/http"
"strings"
"sync"
"sync/atomic"
"time"
"masterhttprelayvpn/internal/config"
@@ -56,8 +57,9 @@ type SOCKSState struct {
OutboundSequence uint64
PendingInbound map[uint64][]PendingInboundPacket
UpstreamConn net.Conn
activityMu sync.RWMutex
upstreamStateMu sync.RWMutex
upstreamWriteMu sync.Mutex
upstreamCloseMu sync.Mutex
upstreamReadEOF bool
upstreamWriteEOF bool
queueMu sync.Mutex
@@ -237,7 +239,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
}
session.SOCKSConnections[packet.SOCKSID] = socksState
} else {
socksState.LastActivityAt = now
socksState.setLastActivityAt(now)
socksState.Target = packet.Target
socksState.ConnectSeen = true
if socksState.PendingInbound == nil {
@@ -248,7 +250,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
}
}
if socksState.UpstreamConn == nil {
if _, connected := socksState.currentUpstreamConn(); !connected {
s.log.Debugf(
"<gray>dial upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey,
@@ -267,7 +269,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
return []protocol.Packet{response}, nil
}
socksState.UpstreamConn = upstreamConn
socksState.setUpstreamConn(upstreamConn)
socksState.ConnectAcked = true
s.log.Infof(
"<green>upstream connected socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></green>",
@@ -408,7 +410,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
}
func (s *Server) applyOrderedPacketLocked(session *ClientSession, socksState *SOCKSState, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) {
socksState.LastActivityAt = now
socksState.setLastActivityAt(now)
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
}
@@ -559,13 +561,14 @@ func (s *Server) drainSessionOutboundLocked(session *ClientSession) []protocol.P
}
func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSState) {
if socksState.UpstreamConn == nil {
return
}
buffer := make([]byte, s.cfg.MaxChunkSize)
for {
n, err := socksState.UpstreamConn.Read(buffer)
upstreamConn, ok := socksState.currentUpstreamConn()
if !ok {
return
}
n, err := upstreamConn.Read(buffer)
if n > 0 {
chunk := append([]byte(nil), buffer[:n]...)
if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) {
@@ -577,7 +580,7 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
_ = socksState.closeUpstream()
return
}
socksState.LastActivityAt = time.Now()
socksState.setLastActivityAt(time.Now())
queueDepth, queueBytes := socksState.queueSnapshot()
s.log.Debugf(
"<gray>read upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> bytes=<cyan>%d</cyan> queue_depth=<cyan>%d</cyan> queue_bytes=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
@@ -614,8 +617,7 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
}
func (s *SOCKSState) nextOutboundSequence() uint64 {
s.OutboundSequence++
return s.OutboundSequence
return atomic.AddUint64(&s.OutboundSequence, 1)
}
func (s *SOCKSState) expectedInboundSequenceLocked() uint64 {
@@ -821,58 +823,97 @@ func (s *SOCKSState) writeUpstream(payload []byte) error {
}
s.upstreamWriteMu.Lock()
defer s.upstreamWriteMu.Unlock()
if s.UpstreamConn == nil {
upstreamConn, ok := s.currentUpstreamConn()
if !ok {
return fmt.Errorf("upstream connection is not established")
}
_, err := s.UpstreamConn.Write(payload)
_, err := upstreamConn.Write(payload)
return err
}
func (s *SOCKSState) closeUpstream() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.UpstreamConn == nil {
s.upstreamStateMu.Lock()
upstreamConn := s.UpstreamConn
if upstreamConn == nil {
s.upstreamStateMu.Unlock()
return nil
}
err := s.UpstreamConn.Close()
s.UpstreamConn = nil
s.upstreamReadEOF = true
s.upstreamWriteEOF = true
s.upstreamStateMu.Unlock()
err := upstreamConn.Close()
return err
}
func (s *SOCKSState) closeUpstreamRead() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
s.upstreamStateMu.Lock()
if s.upstreamReadEOF {
s.upstreamStateMu.Unlock()
return nil
}
s.upstreamReadEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
upstreamConn := s.UpstreamConn
s.upstreamStateMu.Unlock()
if tcpConn, ok := upstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseRead()
}
return nil
}
func (s *SOCKSState) closeUpstreamWrite() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
s.upstreamStateMu.Lock()
if s.upstreamWriteEOF {
s.upstreamStateMu.Unlock()
return nil
}
s.upstreamWriteEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
upstreamConn := s.UpstreamConn
s.upstreamStateMu.Unlock()
if tcpConn, ok := upstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseWrite()
}
if s.UpstreamConn == nil {
if upstreamConn == nil {
return nil
}
err := s.UpstreamConn.Close()
err := upstreamConn.Close()
s.upstreamStateMu.Lock()
s.UpstreamConn = nil
s.upstreamReadEOF = true
s.upstreamStateMu.Unlock()
return err
}
func (s *SOCKSState) setUpstreamConn(conn net.Conn) {
s.upstreamStateMu.Lock()
defer s.upstreamStateMu.Unlock()
s.UpstreamConn = conn
s.upstreamReadEOF = false
s.upstreamWriteEOF = false
}
func (s *SOCKSState) currentUpstreamConn() (net.Conn, bool) {
s.upstreamStateMu.RLock()
defer s.upstreamStateMu.RUnlock()
if s.UpstreamConn == nil {
return nil, false
}
return s.UpstreamConn, true
}
func (s *SOCKSState) setLastActivityAt(now time.Time) {
s.activityMu.Lock()
s.LastActivityAt = now
s.activityMu.Unlock()
}
func (s *SOCKSState) lastActivityAt() time.Time {
s.activityMu.RLock()
defer s.activityMu.RUnlock()
return s.LastActivityAt
}
func (s *Server) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
@@ -903,7 +944,7 @@ func (s *Server) cleanupExpired() {
for clientSessionKey, session := range sessionSnapshots {
session.mu.Lock()
for socksID, socksState := range session.SOCKSConnections {
if now.Sub(socksState.LastActivityAt) > socksTTL {
if now.Sub(socksState.lastActivityAt()) > socksTTL {
targetAddress := targetAddressForLog(socksState.Target)
_ = socksState.closeUpstream()
socksState.release()
+63
View File
@@ -3,6 +3,7 @@ package server
import (
"errors"
"net"
"sync"
"testing"
"time"
@@ -256,6 +257,68 @@ func TestProcessBatchBlockedSessionDoesNotBlockOtherSessions(t *testing.T) {
}
}
func TestSOCKSStateNextOutboundSequenceIsConcurrentSafe(t *testing.T) {
socksState := &SOCKSState{}
const workers = 32
const iterationsPerWorker = 64
start := make(chan struct{})
var wg sync.WaitGroup
results := make(chan uint64, workers*iterationsPerWorker)
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
for j := 0; j < iterationsPerWorker; j++ {
results <- socksState.nextOutboundSequence()
}
}()
}
close(start)
wg.Wait()
close(results)
seen := make(map[uint64]struct{}, workers*iterationsPerWorker)
var maxSeq uint64
for seq := range results {
if _, exists := seen[seq]; exists {
t.Fatalf("duplicate outbound sequence generated: %d", seq)
}
seen[seq] = struct{}{}
if seq > maxSeq {
maxSeq = seq
}
}
expected := workers * iterationsPerWorker
if len(seen) != expected {
t.Fatalf("expected %d unique sequences, got %d", expected, len(seen))
}
if maxSeq != uint64(expected) {
t.Fatalf("expected max sequence %d, got %d", expected, maxSeq)
}
}
func TestSOCKSStateCloseUpstreamClearsConnectionSnapshot(t *testing.T) {
serverConn, clientConn := net.Pipe()
defer clientConn.Close()
socksState := &SOCKSState{}
socksState.setUpstreamConn(serverConn)
if err := socksState.closeUpstream(); err != nil {
t.Fatalf("expected closeUpstream to succeed, got %v", err)
}
if _, ok := socksState.currentUpstreamConn(); ok {
t.Fatal("expected upstream connection snapshot to be cleared after close")
}
}
func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = socksID