mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-18 05:44:35 +03:00
Make server outbound sequencing and upstream connection lifecycle concurrency-safe
This commit is contained in:
+69
-28
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"masterhttprelayvpn/internal/config"
|
||||
@@ -56,8 +57,9 @@ type SOCKSState struct {
|
||||
OutboundSequence uint64
|
||||
PendingInbound map[uint64][]PendingInboundPacket
|
||||
UpstreamConn net.Conn
|
||||
activityMu sync.RWMutex
|
||||
upstreamStateMu sync.RWMutex
|
||||
upstreamWriteMu sync.Mutex
|
||||
upstreamCloseMu sync.Mutex
|
||||
upstreamReadEOF bool
|
||||
upstreamWriteEOF bool
|
||||
queueMu sync.Mutex
|
||||
@@ -237,7 +239,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
}
|
||||
session.SOCKSConnections[packet.SOCKSID] = socksState
|
||||
} else {
|
||||
socksState.LastActivityAt = now
|
||||
socksState.setLastActivityAt(now)
|
||||
socksState.Target = packet.Target
|
||||
socksState.ConnectSeen = true
|
||||
if socksState.PendingInbound == nil {
|
||||
@@ -248,7 +250,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
}
|
||||
}
|
||||
|
||||
if socksState.UpstreamConn == nil {
|
||||
if _, connected := socksState.currentUpstreamConn(); !connected {
|
||||
s.log.Debugf(
|
||||
"<gray>dial upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey,
|
||||
@@ -267,7 +269,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
return []protocol.Packet{response}, nil
|
||||
}
|
||||
|
||||
socksState.UpstreamConn = upstreamConn
|
||||
socksState.setUpstreamConn(upstreamConn)
|
||||
socksState.ConnectAcked = true
|
||||
s.log.Infof(
|
||||
"<green>upstream connected socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></green>",
|
||||
@@ -408,7 +410,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
|
||||
}
|
||||
|
||||
func (s *Server) applyOrderedPacketLocked(session *ClientSession, socksState *SOCKSState, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) {
|
||||
socksState.LastActivityAt = now
|
||||
socksState.setLastActivityAt(now)
|
||||
if packet.Sequence > socksState.LastSequenceSeen {
|
||||
socksState.LastSequenceSeen = packet.Sequence
|
||||
}
|
||||
@@ -559,13 +561,14 @@ func (s *Server) drainSessionOutboundLocked(session *ClientSession) []protocol.P
|
||||
}
|
||||
|
||||
func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSState) {
|
||||
if socksState.UpstreamConn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
buffer := make([]byte, s.cfg.MaxChunkSize)
|
||||
for {
|
||||
n, err := socksState.UpstreamConn.Read(buffer)
|
||||
upstreamConn, ok := socksState.currentUpstreamConn()
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
n, err := upstreamConn.Read(buffer)
|
||||
if n > 0 {
|
||||
chunk := append([]byte(nil), buffer[:n]...)
|
||||
if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) {
|
||||
@@ -577,7 +580,7 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
|
||||
_ = socksState.closeUpstream()
|
||||
return
|
||||
}
|
||||
socksState.LastActivityAt = time.Now()
|
||||
socksState.setLastActivityAt(time.Now())
|
||||
queueDepth, queueBytes := socksState.queueSnapshot()
|
||||
s.log.Debugf(
|
||||
"<gray>read upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> bytes=<cyan>%d</cyan> queue_depth=<cyan>%d</cyan> queue_bytes=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
@@ -614,8 +617,7 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
|
||||
}
|
||||
|
||||
func (s *SOCKSState) nextOutboundSequence() uint64 {
|
||||
s.OutboundSequence++
|
||||
return s.OutboundSequence
|
||||
return atomic.AddUint64(&s.OutboundSequence, 1)
|
||||
}
|
||||
|
||||
func (s *SOCKSState) expectedInboundSequenceLocked() uint64 {
|
||||
@@ -821,58 +823,97 @@ func (s *SOCKSState) writeUpstream(payload []byte) error {
|
||||
}
|
||||
s.upstreamWriteMu.Lock()
|
||||
defer s.upstreamWriteMu.Unlock()
|
||||
if s.UpstreamConn == nil {
|
||||
upstreamConn, ok := s.currentUpstreamConn()
|
||||
if !ok {
|
||||
return fmt.Errorf("upstream connection is not established")
|
||||
}
|
||||
_, err := s.UpstreamConn.Write(payload)
|
||||
_, err := upstreamConn.Write(payload)
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SOCKSState) closeUpstream() error {
|
||||
s.upstreamCloseMu.Lock()
|
||||
defer s.upstreamCloseMu.Unlock()
|
||||
if s.UpstreamConn == nil {
|
||||
s.upstreamStateMu.Lock()
|
||||
upstreamConn := s.UpstreamConn
|
||||
if upstreamConn == nil {
|
||||
s.upstreamStateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
err := s.UpstreamConn.Close()
|
||||
s.UpstreamConn = nil
|
||||
s.upstreamReadEOF = true
|
||||
s.upstreamWriteEOF = true
|
||||
s.upstreamStateMu.Unlock()
|
||||
err := upstreamConn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SOCKSState) closeUpstreamRead() error {
|
||||
s.upstreamCloseMu.Lock()
|
||||
defer s.upstreamCloseMu.Unlock()
|
||||
s.upstreamStateMu.Lock()
|
||||
if s.upstreamReadEOF {
|
||||
s.upstreamStateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.upstreamReadEOF = true
|
||||
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
|
||||
upstreamConn := s.UpstreamConn
|
||||
s.upstreamStateMu.Unlock()
|
||||
if tcpConn, ok := upstreamConn.(*net.TCPConn); ok {
|
||||
return tcpConn.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SOCKSState) closeUpstreamWrite() error {
|
||||
s.upstreamCloseMu.Lock()
|
||||
defer s.upstreamCloseMu.Unlock()
|
||||
s.upstreamStateMu.Lock()
|
||||
if s.upstreamWriteEOF {
|
||||
s.upstreamStateMu.Unlock()
|
||||
return nil
|
||||
}
|
||||
s.upstreamWriteEOF = true
|
||||
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
|
||||
upstreamConn := s.UpstreamConn
|
||||
s.upstreamStateMu.Unlock()
|
||||
if tcpConn, ok := upstreamConn.(*net.TCPConn); ok {
|
||||
return tcpConn.CloseWrite()
|
||||
}
|
||||
if s.UpstreamConn == nil {
|
||||
if upstreamConn == nil {
|
||||
return nil
|
||||
}
|
||||
err := s.UpstreamConn.Close()
|
||||
|
||||
err := upstreamConn.Close()
|
||||
s.upstreamStateMu.Lock()
|
||||
s.UpstreamConn = nil
|
||||
s.upstreamReadEOF = true
|
||||
s.upstreamStateMu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SOCKSState) setUpstreamConn(conn net.Conn) {
|
||||
s.upstreamStateMu.Lock()
|
||||
defer s.upstreamStateMu.Unlock()
|
||||
s.UpstreamConn = conn
|
||||
s.upstreamReadEOF = false
|
||||
s.upstreamWriteEOF = false
|
||||
}
|
||||
|
||||
func (s *SOCKSState) currentUpstreamConn() (net.Conn, bool) {
|
||||
s.upstreamStateMu.RLock()
|
||||
defer s.upstreamStateMu.RUnlock()
|
||||
if s.UpstreamConn == nil {
|
||||
return nil, false
|
||||
}
|
||||
return s.UpstreamConn, true
|
||||
}
|
||||
|
||||
func (s *SOCKSState) setLastActivityAt(now time.Time) {
|
||||
s.activityMu.Lock()
|
||||
s.LastActivityAt = now
|
||||
s.activityMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *SOCKSState) lastActivityAt() time.Time {
|
||||
s.activityMu.RLock()
|
||||
defer s.activityMu.RUnlock()
|
||||
return s.LastActivityAt
|
||||
}
|
||||
|
||||
func (s *Server) cleanupLoop(ctx context.Context) {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
@@ -903,7 +944,7 @@ func (s *Server) cleanupExpired() {
|
||||
for clientSessionKey, session := range sessionSnapshots {
|
||||
session.mu.Lock()
|
||||
for socksID, socksState := range session.SOCKSConnections {
|
||||
if now.Sub(socksState.LastActivityAt) > socksTTL {
|
||||
if now.Sub(socksState.lastActivityAt()) > socksTTL {
|
||||
targetAddress := targetAddressForLog(socksState.Target)
|
||||
_ = socksState.closeUpstream()
|
||||
socksState.release()
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -256,6 +257,68 @@ func TestProcessBatchBlockedSessionDoesNotBlockOtherSessions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSOCKSStateNextOutboundSequenceIsConcurrentSafe(t *testing.T) {
|
||||
socksState := &SOCKSState{}
|
||||
|
||||
const workers = 32
|
||||
const iterationsPerWorker = 64
|
||||
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
results := make(chan uint64, workers*iterationsPerWorker)
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
for j := 0; j < iterationsPerWorker; j++ {
|
||||
results <- socksState.nextOutboundSequence()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
seen := make(map[uint64]struct{}, workers*iterationsPerWorker)
|
||||
var maxSeq uint64
|
||||
for seq := range results {
|
||||
if _, exists := seen[seq]; exists {
|
||||
t.Fatalf("duplicate outbound sequence generated: %d", seq)
|
||||
}
|
||||
seen[seq] = struct{}{}
|
||||
if seq > maxSeq {
|
||||
maxSeq = seq
|
||||
}
|
||||
}
|
||||
|
||||
expected := workers * iterationsPerWorker
|
||||
if len(seen) != expected {
|
||||
t.Fatalf("expected %d unique sequences, got %d", expected, len(seen))
|
||||
}
|
||||
if maxSeq != uint64(expected) {
|
||||
t.Fatalf("expected max sequence %d, got %d", expected, maxSeq)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSOCKSStateCloseUpstreamClearsConnectionSnapshot(t *testing.T) {
|
||||
serverConn, clientConn := net.Pipe()
|
||||
defer clientConn.Close()
|
||||
|
||||
socksState := &SOCKSState{}
|
||||
socksState.setUpstreamConn(serverConn)
|
||||
|
||||
if err := socksState.closeUpstream(); err != nil {
|
||||
t.Fatalf("expected closeUpstream to succeed, got %v", err)
|
||||
}
|
||||
|
||||
if _, ok := socksState.currentUpstreamConn(); ok {
|
||||
t.Fatal("expected upstream connection snapshot to be cleared after close")
|
||||
}
|
||||
}
|
||||
|
||||
func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet {
|
||||
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
|
||||
packet.SOCKSID = socksID
|
||||
|
||||
Reference in New Issue
Block a user