Improve half-close handling and queue backpressure

This commit is contained in:
Amin.MasterkinG
2026-04-20 20:18:08 +03:30
parent 2baf5e8718
commit 025923fe89
7 changed files with 207 additions and 27 deletions
+1
View File
@@ -16,6 +16,7 @@ MAX_BATCH_BYTES = 262144
WORKER_COUNT = 4 WORKER_COUNT = 4
HTTP_REQUEST_TIMEOUT_MS = 15000 HTTP_REQUEST_TIMEOUT_MS = 15000
WORKER_POLL_INTERVAL_MS = 200 WORKER_POLL_INTERVAL_MS = 200
IDLE_POLL_INTERVAL_MS = 1000
MAX_QUEUE_BYTES_PER_SOCKS = 1048576 MAX_QUEUE_BYTES_PER_SOCKS = 1048576
ACK_TIMEOUT_MS = 5000 ACK_TIMEOUT_MS = 5000
MAX_RETRY_COUNT = 5 MAX_RETRY_COUNT = 5
+21 -2
View File
@@ -151,7 +151,7 @@ func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch,
now := time.Now() now := time.Now()
nowUnixMS := now.UnixMilli() nowUnixMS := now.UnixMilli()
lastUnixMS := c.lastPollUnixMS.Load() lastUnixMS := c.lastPollUnixMS.Load()
minInterval := time.Duration(c.cfg.WorkerPollIntervalMS) * time.Millisecond minInterval := time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond
if lastUnixMS > 0 && nowUnixMS-lastUnixMS < minInterval.Milliseconds() { if lastUnixMS > 0 && nowUnixMS-lastUnixMS < minInterval.Milliseconds() {
return protocol.Batch{}, false return protocol.Batch{}, false
} }
@@ -304,7 +304,26 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error {
socksConn.LastActivityAt = time.Now() socksConn.LastActivityAt = time.Now()
return socksConn.WriteToLocal(packet.Payload) return socksConn.WriteToLocal(packet.Payload)
case protocol.PacketTypeSOCKSCloseRead, protocol.PacketTypeSOCKSCloseWrite, protocol.PacketTypeSOCKSRST: case protocol.PacketTypeSOCKSCloseRead:
_ = socksConn.AckPacket(packet)
socksConn.LastActivityAt = time.Now()
if err := socksConn.CloseLocalWrite(); err != nil {
return err
}
if socksConn.BothLocalSidesClosed() {
return socksConn.CloseLocal()
}
return nil
case protocol.PacketTypeSOCKSCloseWrite:
_ = socksConn.AckPacket(packet)
socksConn.LastActivityAt = time.Now()
if socksConn.BothLocalSidesClosed() {
return socksConn.CloseLocal()
}
return nil
case protocol.PacketTypeSOCKSRST:
_ = socksConn.AckPacket(packet) _ = socksConn.AckPacket(packet)
socksConn.LastActivityAt = time.Now() socksConn.LastActivityAt = time.Now()
return socksConn.CloseLocal() return socksConn.CloseLocal()
+63 -4
View File
@@ -7,6 +7,7 @@
package client package client
import ( import (
"context"
"encoding/hex" "encoding/hex"
"net" "net"
"sync" "sync"
@@ -38,6 +39,11 @@ type SOCKSConnection struct {
LocalConn net.Conn LocalConn net.Conn
localWriteMu sync.Mutex localWriteMu sync.Mutex
localCloseMu sync.Mutex
localReadEOF bool
localWriteEOF bool
closedC chan struct{}
closeOnce sync.Once
connectResultC chan error connectResultC chan error
queueMu sync.Mutex queueMu sync.Mutex
OutboundQueue []*SOCKSOutboundQueueItem OutboundQueue []*SOCKSOutboundQueueItem
@@ -74,6 +80,7 @@ func (s *SOCKSConnectionStore) New(clientSessionKey string, clientAddress string
CreatedAt: now, CreatedAt: now,
LastActivityAt: now, LastActivityAt: now,
ClientAddress: clientAddress, ClientAddress: clientAddress,
closedC: make(chan struct{}),
connectResultC: make(chan error, 1), connectResultC: make(chan error, 1),
InFlight: make(map[string]*SOCKSOutboundQueueItem), InFlight: make(map[string]*SOCKSOutboundQueueItem),
} }
@@ -112,13 +119,65 @@ func (s *SOCKSConnection) WriteToLocal(payload []byte) error {
} }
func (s *SOCKSConnection) CloseLocal() error { func (s *SOCKSConnection) CloseLocal() error {
s.localWriteMu.Lock() var err error
defer s.localWriteMu.Unlock() s.closeOnce.Do(func() {
s.localWriteMu.Lock()
defer s.localWriteMu.Unlock()
if s.LocalConn != nil {
err = s.LocalConn.Close()
}
close(s.closedC)
})
return err
}
if s.LocalConn == nil { func (s *SOCKSConnection) CloseLocalWrite() error {
s.localCloseMu.Lock()
defer s.localCloseMu.Unlock()
if s.localWriteEOF {
return nil return nil
} }
return s.LocalConn.Close() s.localWriteEOF = true
if tcpConn, ok := s.LocalConn.(*net.TCPConn); ok {
return tcpConn.CloseWrite()
}
return s.CloseLocal()
}
func (s *SOCKSConnection) CloseLocalRead() error {
s.localCloseMu.Lock()
defer s.localCloseMu.Unlock()
if s.localReadEOF {
return nil
}
s.localReadEOF = true
if tcpConn, ok := s.LocalConn.(*net.TCPConn); ok {
return tcpConn.CloseRead()
}
return nil
}
func (s *SOCKSConnection) MarkLocalReadEOF() {
s.localCloseMu.Lock()
s.localReadEOF = true
s.localCloseMu.Unlock()
}
func (s *SOCKSConnection) BothLocalSidesClosed() bool {
s.localCloseMu.Lock()
defer s.localCloseMu.Unlock()
return s.localReadEOF && s.localWriteEOF
}
func (s *SOCKSConnection) WaitUntilClosed(ctx context.Context) {
select {
case <-ctx.Done():
case <-s.closedC:
}
} }
func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection { func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection {
+2
View File
@@ -316,7 +316,9 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
if err != nil { if err != nil {
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
socksConn.MarkLocalReadEOF()
_ = socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket()) _ = socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket())
socksConn.WaitUntilClosed(ctx)
return nil return nil
} }
if ne, ok := err.(net.Error); ok && ne.Timeout() { if ne, ok := err.(net.Error); ok && ne.Timeout() {
+22
View File
@@ -32,12 +32,14 @@ type Config struct {
WorkerCount int WorkerCount int
HTTPRequestTimeoutMS int HTTPRequestTimeoutMS int
WorkerPollIntervalMS int WorkerPollIntervalMS int
IdlePollIntervalMS int
MaxQueueBytesPerSOCKS int MaxQueueBytesPerSOCKS int
AckTimeoutMS int AckTimeoutMS int
MaxRetryCount int MaxRetryCount int
SessionIdleTimeoutMS int SessionIdleTimeoutMS int
SOCKSIdleTimeoutMS int SOCKSIdleTimeoutMS int
ReadBodyLimitBytes int ReadBodyLimitBytes int
MaxServerQueueBytes int
} }
func Load(path string) (Config, error) { func Load(path string) (Config, error) {
@@ -53,12 +55,14 @@ func Load(path string) (Config, error) {
WorkerCount: 4, WorkerCount: 4,
HTTPRequestTimeoutMS: 15000, HTTPRequestTimeoutMS: 15000,
WorkerPollIntervalMS: 200, WorkerPollIntervalMS: 200,
IdlePollIntervalMS: 1000,
MaxQueueBytesPerSOCKS: 1024 * 1024, MaxQueueBytesPerSOCKS: 1024 * 1024,
AckTimeoutMS: 5000, AckTimeoutMS: 5000,
MaxRetryCount: 5, MaxRetryCount: 5,
SessionIdleTimeoutMS: 5 * 60 * 1000, SessionIdleTimeoutMS: 5 * 60 * 1000,
SOCKSIdleTimeoutMS: 2 * 60 * 1000, SOCKSIdleTimeoutMS: 2 * 60 * 1000,
ReadBodyLimitBytes: 2 * 1024 * 1024, ReadBodyLimitBytes: 2 * 1024 * 1024,
MaxServerQueueBytes: 2 * 1024 * 1024,
} }
file, err := os.Open(path) file, err := os.Open(path)
@@ -151,6 +155,12 @@ func Load(path string) (Config, error) {
return Config{}, fmt.Errorf("parse WORKER_POLL_INTERVAL_MS: %w", err) return Config{}, fmt.Errorf("parse WORKER_POLL_INTERVAL_MS: %w", err)
} }
cfg.WorkerPollIntervalMS = interval cfg.WorkerPollIntervalMS = interval
case "IDLE_POLL_INTERVAL_MS":
interval, err := strconv.Atoi(value)
if err != nil {
return Config{}, fmt.Errorf("parse IDLE_POLL_INTERVAL_MS: %w", err)
}
cfg.IdlePollIntervalMS = interval
case "MAX_QUEUE_BYTES_PER_SOCKS": case "MAX_QUEUE_BYTES_PER_SOCKS":
size, err := strconv.Atoi(value) size, err := strconv.Atoi(value)
if err != nil { if err != nil {
@@ -187,6 +197,12 @@ func Load(path string) (Config, error) {
return Config{}, fmt.Errorf("parse READ_BODY_LIMIT_BYTES: %w", err) return Config{}, fmt.Errorf("parse READ_BODY_LIMIT_BYTES: %w", err)
} }
cfg.ReadBodyLimitBytes = size cfg.ReadBodyLimitBytes = size
case "MAX_SERVER_QUEUE_BYTES":
size, err := strconv.Atoi(value)
if err != nil {
return Config{}, fmt.Errorf("parse MAX_SERVER_QUEUE_BYTES: %w", err)
}
cfg.MaxServerQueueBytes = size
} }
} }
@@ -216,6 +232,9 @@ func (c Config) ValidateClient() error {
if c.WorkerPollIntervalMS < 1 { if c.WorkerPollIntervalMS < 1 {
return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS) return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS)
} }
if c.IdlePollIntervalMS < c.WorkerPollIntervalMS {
return fmt.Errorf("IDLE_POLL_INTERVAL_MS must be >= WORKER_POLL_INTERVAL_MS")
}
if c.AckTimeoutMS < 1 { if c.AckTimeoutMS < 1 {
return fmt.Errorf("invalid ACK_TIMEOUT_MS: %d", c.AckTimeoutMS) return fmt.Errorf("invalid ACK_TIMEOUT_MS: %d", c.AckTimeoutMS)
} }
@@ -244,6 +263,9 @@ func (c Config) ValidateServer() error {
if c.ReadBodyLimitBytes < c.MaxChunkSize { if c.ReadBodyLimitBytes < c.MaxChunkSize {
return fmt.Errorf("READ_BODY_LIMIT_BYTES must be >= MAX_CHUNK_SIZE") return fmt.Errorf("READ_BODY_LIMIT_BYTES must be >= MAX_CHUNK_SIZE")
} }
if c.MaxServerQueueBytes < c.MaxChunkSize {
return fmt.Errorf("MAX_SERVER_QUEUE_BYTES must be >= MAX_CHUNK_SIZE")
}
return nil return nil
} }
+97 -21
View File
@@ -13,6 +13,7 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@@ -51,9 +52,13 @@ type SOCKSState struct {
OutboundSequence uint64 OutboundSequence uint64
UpstreamConn net.Conn UpstreamConn net.Conn
upstreamWriteMu sync.Mutex upstreamWriteMu sync.Mutex
upstreamCloseMu sync.Mutex
upstreamReadEOF bool
upstreamWriteEOF bool
queueMu sync.Mutex queueMu sync.Mutex
OutboundQueue []protocol.Packet OutboundQueue []protocol.Packet
QueuedBytes int QueuedBytes int
MaxQueueBytes int
} }
func New(cfg config.Config, lg *logger.Logger) *Server { func New(cfg config.Config, lg *logger.Logger) *Server {
@@ -215,6 +220,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
Target: packet.Target, Target: packet.Target,
ConnectSeen: true, ConnectSeen: true,
LastSequenceSeen: packet.Sequence, LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
} }
session.SOCKSConnections[packet.SOCKSID] = socksState session.SOCKSConnections[packet.SOCKSID] = socksState
} else { } else {
@@ -301,7 +307,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
"<gray>received close_read socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>", "<gray>received close_read socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey, packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
) )
_ = socksState.closeUpstream() _ = socksState.closeUpstreamRead()
return &response, nil return &response, nil
case protocol.PacketTypeSOCKSCloseWrite: case protocol.PacketTypeSOCKSCloseWrite:
@@ -318,7 +324,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
"<gray>received close_write socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>", "<gray>received close_write socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey, packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
) )
_ = socksState.closeUpstream() _ = socksState.closeUpstreamWrite()
return &response, nil return &response, nil
case protocol.PacketTypeSOCKSRST: case protocol.PacketTypeSOCKSRST:
@@ -408,6 +414,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
LastActivityAt: now, LastActivityAt: now,
Target: packet.Target, Target: packet.Target,
LastSequenceSeen: packet.Sequence, LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
} }
session.SOCKSConnections[packet.SOCKSID] = socksState session.SOCKSConnections[packet.SOCKSID] = socksState
s.log.Debugf( s.log.Debugf(
@@ -442,7 +449,15 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
n, err := socksState.UpstreamConn.Read(buffer) n, err := socksState.UpstreamConn.Read(buffer)
if n > 0 { if n > 0 {
chunk := append([]byte(nil), buffer[:n]...) chunk := append([]byte(nil), buffer[:n]...)
socksState.enqueueOutboundData(clientSessionKey, chunk, false) if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) {
s.log.Warnf(
"<yellow>server outbound queue full socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></yellow>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
)
socksState.forceResetPacket(clientSessionKey)
_ = socksState.closeUpstream()
return
}
socksState.LastActivityAt = time.Now() socksState.LastActivityAt = time.Now()
queueDepth, queueBytes := socksState.queueSnapshot() queueDepth, queueBytes := socksState.queueSnapshot()
s.log.Debugf( s.log.Debugf(
@@ -457,14 +472,22 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
"<gray>upstream eof socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>", "<gray>upstream eof socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
) )
socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSCloseRead, true) _ = socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSCloseRead, true)
} else { _ = socksState.closeUpstreamRead()
s.log.Warnf( return
"<yellow>upstream read failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, err,
)
socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSRST, true)
} }
if isClosedConnError(err) {
s.log.Debugf(
"<gray>upstream closed locally socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
)
return
}
s.log.Warnf(
"<yellow>upstream read failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, err,
)
socksState.forceResetPacket(clientSessionKey)
_ = socksState.closeUpstream() _ = socksState.closeUpstream()
return return
} }
@@ -476,28 +499,45 @@ func (s *SOCKSState) nextOutboundSequence() uint64 {
return s.OutboundSequence return s.OutboundSequence
} }
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) { func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData) packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = s.ID packet.SOCKSID = s.ID
packet.Sequence = s.nextOutboundSequence() packet.Sequence = s.nextOutboundSequence()
packet.Final = final packet.Final = final
packet.Payload = payload packet.Payload = payload
s.enqueuePacket(packet) return s.enqueuePacket(packet)
} }
func (s *SOCKSState) enqueueControlPacket(clientSessionKey string, packetType protocol.PacketType, final bool) { func (s *SOCKSState) enqueueControlPacket(clientSessionKey string, packetType protocol.PacketType, final bool) bool {
packet := protocol.NewPacket(clientSessionKey, packetType) packet := protocol.NewPacket(clientSessionKey, packetType)
packet.SOCKSID = s.ID packet.SOCKSID = s.ID
packet.Sequence = s.nextOutboundSequence() packet.Sequence = s.nextOutboundSequence()
packet.Final = final packet.Final = final
s.enqueuePacket(packet) return s.enqueuePacket(packet)
} }
func (s *SOCKSState) enqueuePacket(packet protocol.Packet) { func (s *SOCKSState) enqueuePacket(packet protocol.Packet) bool {
s.queueMu.Lock() s.queueMu.Lock()
defer s.queueMu.Unlock() defer s.queueMu.Unlock()
packetBytes := len(packet.Payload)
if s.MaxQueueBytes > 0 && s.QueuedBytes+packetBytes > s.MaxQueueBytes {
return false
}
s.OutboundQueue = append(s.OutboundQueue, packet) s.OutboundQueue = append(s.OutboundQueue, packet)
s.QueuedBytes += len(packet.Payload) s.QueuedBytes += packetBytes
return true
}
func (s *SOCKSState) forceResetPacket(clientSessionKey string) {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSRST)
packet.SOCKSID = s.ID
packet.Sequence = s.nextOutboundSequence()
packet.Final = true
s.queueMu.Lock()
defer s.queueMu.Unlock()
s.OutboundQueue = []protocol.Packet{packet}
s.QueuedBytes = 0
} }
func (s *SOCKSState) queueSnapshot() (items int, bytes int) { func (s *SOCKSState) queueSnapshot() (items int, bytes int) {
@@ -552,18 +592,47 @@ func (s *SOCKSState) writeUpstream(payload []byte) error {
} }
func (s *SOCKSState) closeUpstream() error { func (s *SOCKSState) closeUpstream() error {
s.upstreamWriteMu.Lock() s.upstreamCloseMu.Lock()
defer s.upstreamWriteMu.Unlock() defer s.upstreamCloseMu.Unlock()
if s.UpstreamConn == nil { if s.UpstreamConn == nil {
return nil return nil
} }
target := targetAddressForLog(s.Target)
err := s.UpstreamConn.Close() err := s.UpstreamConn.Close()
s.UpstreamConn = nil s.UpstreamConn = nil
if err == nil { s.upstreamReadEOF = true
s.upstreamWriteEOF = true
return err
}
func (s *SOCKSState) closeUpstreamRead() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.upstreamReadEOF {
return nil return nil
} }
_ = target s.upstreamReadEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseRead()
}
return nil
}
func (s *SOCKSState) closeUpstreamWrite() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.upstreamWriteEOF {
return nil
}
s.upstreamWriteEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseWrite()
}
if s.UpstreamConn == nil {
return nil
}
err := s.UpstreamConn.Close()
s.UpstreamConn = nil
s.upstreamReadEOF = true
return err return err
} }
@@ -626,3 +695,10 @@ func targetAddressForLog(target *protocol.Target) string {
} }
return target.Address() return target.Address()
} }
func isClosedConnError(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "use of closed network connection")
}
+1
View File
@@ -12,4 +12,5 @@ WORKER_COUNT = 4
SESSION_IDLE_TIMEOUT_MS = 300000 SESSION_IDLE_TIMEOUT_MS = 300000
SOCKS_IDLE_TIMEOUT_MS = 120000 SOCKS_IDLE_TIMEOUT_MS = 120000
READ_BODY_LIMIT_BYTES = 2097152 READ_BODY_LIMIT_BYTES = 2097152
MAX_SERVER_QUEUE_BYTES = 2097152
# ============================================================================== # ==============================================================================