Improve half-close handling and queue backpressure

This commit is contained in:
Amin.MasterkinG
2026-04-20 20:18:08 +03:30
parent 2baf5e8718
commit 025923fe89
7 changed files with 207 additions and 27 deletions
+97 -21
View File
@@ -13,6 +13,7 @@ import (
"io"
"net"
"net/http"
"strings"
"sync"
"time"
@@ -51,9 +52,13 @@ type SOCKSState struct {
OutboundSequence uint64
UpstreamConn net.Conn
upstreamWriteMu sync.Mutex
upstreamCloseMu sync.Mutex
upstreamReadEOF bool
upstreamWriteEOF bool
queueMu sync.Mutex
OutboundQueue []protocol.Packet
QueuedBytes int
MaxQueueBytes int
}
func New(cfg config.Config, lg *logger.Logger) *Server {
@@ -215,6 +220,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
Target: packet.Target,
ConnectSeen: true,
LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
}
session.SOCKSConnections[packet.SOCKSID] = socksState
} else {
@@ -301,7 +307,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
"<gray>received close_read socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstream()
_ = socksState.closeUpstreamRead()
return &response, nil
case protocol.PacketTypeSOCKSCloseWrite:
@@ -318,7 +324,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
"<gray>received close_write socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstream()
_ = socksState.closeUpstreamWrite()
return &response, nil
case protocol.PacketTypeSOCKSRST:
@@ -408,6 +414,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
LastActivityAt: now,
Target: packet.Target,
LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
}
session.SOCKSConnections[packet.SOCKSID] = socksState
s.log.Debugf(
@@ -442,7 +449,15 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
n, err := socksState.UpstreamConn.Read(buffer)
if n > 0 {
chunk := append([]byte(nil), buffer[:n]...)
socksState.enqueueOutboundData(clientSessionKey, chunk, false)
if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) {
s.log.Warnf(
"<yellow>server outbound queue full socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></yellow>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
)
socksState.forceResetPacket(clientSessionKey)
_ = socksState.closeUpstream()
return
}
socksState.LastActivityAt = time.Now()
queueDepth, queueBytes := socksState.queueSnapshot()
s.log.Debugf(
@@ -457,14 +472,22 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
"<gray>upstream eof socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
)
socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSCloseRead, true)
} else {
s.log.Warnf(
"<yellow>upstream read failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, err,
)
socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSRST, true)
_ = socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSCloseRead, true)
_ = socksState.closeUpstreamRead()
return
}
if isClosedConnError(err) {
s.log.Debugf(
"<gray>upstream closed locally socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
)
return
}
s.log.Warnf(
"<yellow>upstream read failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, err,
)
socksState.forceResetPacket(clientSessionKey)
_ = socksState.closeUpstream()
return
}
@@ -476,28 +499,45 @@ func (s *SOCKSState) nextOutboundSequence() uint64 {
return s.OutboundSequence
}
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) {
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = s.ID
packet.Sequence = s.nextOutboundSequence()
packet.Final = final
packet.Payload = payload
s.enqueuePacket(packet)
return s.enqueuePacket(packet)
}
func (s *SOCKSState) enqueueControlPacket(clientSessionKey string, packetType protocol.PacketType, final bool) {
func (s *SOCKSState) enqueueControlPacket(clientSessionKey string, packetType protocol.PacketType, final bool) bool {
packet := protocol.NewPacket(clientSessionKey, packetType)
packet.SOCKSID = s.ID
packet.Sequence = s.nextOutboundSequence()
packet.Final = final
s.enqueuePacket(packet)
return s.enqueuePacket(packet)
}
func (s *SOCKSState) enqueuePacket(packet protocol.Packet) {
func (s *SOCKSState) enqueuePacket(packet protocol.Packet) bool {
s.queueMu.Lock()
defer s.queueMu.Unlock()
packetBytes := len(packet.Payload)
if s.MaxQueueBytes > 0 && s.QueuedBytes+packetBytes > s.MaxQueueBytes {
return false
}
s.OutboundQueue = append(s.OutboundQueue, packet)
s.QueuedBytes += len(packet.Payload)
s.QueuedBytes += packetBytes
return true
}
func (s *SOCKSState) forceResetPacket(clientSessionKey string) {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSRST)
packet.SOCKSID = s.ID
packet.Sequence = s.nextOutboundSequence()
packet.Final = true
s.queueMu.Lock()
defer s.queueMu.Unlock()
s.OutboundQueue = []protocol.Packet{packet}
s.QueuedBytes = 0
}
func (s *SOCKSState) queueSnapshot() (items int, bytes int) {
@@ -552,18 +592,47 @@ func (s *SOCKSState) writeUpstream(payload []byte) error {
}
func (s *SOCKSState) closeUpstream() error {
s.upstreamWriteMu.Lock()
defer s.upstreamWriteMu.Unlock()
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.UpstreamConn == nil {
return nil
}
target := targetAddressForLog(s.Target)
err := s.UpstreamConn.Close()
s.UpstreamConn = nil
if err == nil {
s.upstreamReadEOF = true
s.upstreamWriteEOF = true
return err
}
func (s *SOCKSState) closeUpstreamRead() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.upstreamReadEOF {
return nil
}
_ = target
s.upstreamReadEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseRead()
}
return nil
}
func (s *SOCKSState) closeUpstreamWrite() error {
s.upstreamCloseMu.Lock()
defer s.upstreamCloseMu.Unlock()
if s.upstreamWriteEOF {
return nil
}
s.upstreamWriteEOF = true
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
return tcpConn.CloseWrite()
}
if s.UpstreamConn == nil {
return nil
}
err := s.UpstreamConn.Close()
s.UpstreamConn = nil
s.upstreamReadEOF = true
return err
}
@@ -626,3 +695,10 @@ func targetAddressForLog(target *protocol.Target) string {
}
return target.Address()
}
func isClosedConnError(err error) bool {
if err == nil {
return false
}
return strings.Contains(err.Error(), "use of closed network connection")
}