Add inbound packet reordering with gap timeout and buffer limits

This commit is contained in:
Amin.MasterkinG
2026-04-21 10:14:47 +03:30
parent fa9311406a
commit 136ddef09a
10 changed files with 623 additions and 88 deletions
+230 -84
View File
@@ -50,7 +50,9 @@ type SOCKSState struct {
ResetSeen bool
ReceivedBytes uint64
LastSequenceSeen uint64
NextInboundSequence uint64
OutboundSequence uint64
PendingInbound map[uint64]PendingInboundPacket
UpstreamConn net.Conn
upstreamWriteMu sync.Mutex
upstreamCloseMu sync.Mutex
@@ -62,6 +64,11 @@ type SOCKSState struct {
MaxQueueBytes int
}
type PendingInboundPacket struct {
Packet protocol.Packet
QueuedAt time.Time
}
func New(cfg config.Config, lg *logger.Logger) *Server {
return &Server{
cfg: cfg,
@@ -168,22 +175,23 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
session.LastActivityAt = now
responses := make([]protocol.Packet, 0, len(batch.Packets))
responses = append(responses, s.expireReorderGapsLocked(session, now)...)
for _, packet := range batch.Packets {
s.log.Debugf(
"<gray>processing batch=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> packet=<cyan>%s</cyan> socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> payload_bytes=<cyan>%d</cyan> final=<cyan>%t</cyan></gray>",
batch.BatchID, batch.ClientSessionKey, packet.Type, packet.SOCKSID, packet.Sequence, len(packet.Payload), packet.Final,
)
response, err := s.processPacketLocked(session, packet, now)
orderedResponses, err := s.processPacketLocked(session, packet, now)
if err != nil {
s.mu.Unlock()
return protocol.Batch{}, err
}
if response != nil {
for _, response := range orderedResponses {
s.log.Debugf(
"<gray>generated direct response packet=<cyan>%s</cyan> socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> payload_bytes=<cyan>%d</cyan></gray>",
response.Type, response.SOCKSID, response.Sequence, len(response.Payload),
)
responses = append(responses, *response)
responses = append(responses, response)
}
}
for _, outbound := range s.drainSessionOutboundLocked(session) {
@@ -201,7 +209,7 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
return protocol.NewBatch(batch.ClientSessionKey, protocol.NewBatchID(), responses), nil
}
func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Packet, now time.Time) (*protocol.Packet, error) {
func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) {
if packet.ClientSessionKey != session.ClientSessionKey {
return nil, fmt.Errorf("packet client session key mismatch")
}
@@ -222,12 +230,16 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
ConnectSeen: true,
LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
PendingInbound: make(map[uint64]PendingInboundPacket),
}
session.SOCKSConnections[packet.SOCKSID] = socksState
} else {
socksState.LastActivityAt = now
socksState.Target = packet.Target
socksState.ConnectSeen = true
if socksState.PendingInbound == nil {
socksState.PendingInbound = make(map[uint64]PendingInboundPacket)
}
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
}
@@ -248,7 +260,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
response.Payload = []byte(err.Error())
return &response, nil
return []protocol.Packet{response}, nil
}
socksState.UpstreamConn = upstreamConn
socksState.ConnectAcked = true
@@ -262,90 +274,48 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSConnectAck)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
return &response, nil
case protocol.PacketTypeSOCKSData:
socksState := s.getOrCreateSOCKSStateLocked(session, packet, now)
socksState.LastActivityAt = now
socksState.ReceivedBytes += uint64(len(packet.Payload))
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
responses := []protocol.Packet{response}
for _, buffered := range socksState.drainReadyInboundLocked() {
drainedResponses, err := s.applyOrderedPacketLocked(session, socksState, buffered, now)
if err != nil {
return nil, err
}
responses = append(responses, drainedResponses...)
}
s.log.Debugf(
"<gray>write upstream socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> payload_bytes=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, len(packet.Payload), session.ClientSessionKey,
)
if err := socksState.writeUpstream(packet.Payload); err != nil {
return responses, nil
case protocol.PacketTypeSOCKSData,
protocol.PacketTypeSOCKSCloseRead,
protocol.PacketTypeSOCKSCloseWrite,
protocol.PacketTypeSOCKSRST:
socksState := s.getOrCreateSOCKSStateLocked(session, packet, now)
readyPackets, duplicate, overflow := socksState.queueInboundPacketLocked(packet, now, s.cfg.MaxReorderBufferPackets)
if duplicate {
return []protocol.Packet{s.duplicateResponsePacket(session, packet)}, nil
}
if overflow {
s.log.Warnf(
"<yellow>write upstream failed socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey, err,
"<yellow>inbound reorder buffer overflow socks_id=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></yellow>",
packet.SOCKSID, session.ClientSessionKey,
)
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
return &response, nil
rst := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST)
rst.SOCKSID = packet.SOCKSID
rst.Sequence = packet.Sequence
_ = socksState.closeUpstream()
socksState.release()
delete(session.SOCKSConnections, packet.SOCKSID)
return []protocol.Packet{rst}, nil
}
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSDataAck)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
response.FragmentID = packet.FragmentID
response.TotalFragments = packet.TotalFragments
response.Final = packet.Final
return &response, nil
case protocol.PacketTypeSOCKSCloseRead:
socksState := s.getOrCreateSOCKSStateLocked(session, packet, now)
socksState.LastActivityAt = now
socksState.CloseReadSeen = true
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
responses := make([]protocol.Packet, 0, len(readyPackets))
for _, readyPacket := range readyPackets {
appliedResponses, err := s.applyOrderedPacketLocked(session, socksState, readyPacket, now)
if err != nil {
return nil, err
}
responses = append(responses, appliedResponses...)
}
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseRead)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
s.log.Debugf(
"<gray>received close_read socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstreamRead()
return &response, nil
case protocol.PacketTypeSOCKSCloseWrite:
socksState := s.getOrCreateSOCKSStateLocked(session, packet, now)
socksState.LastActivityAt = now
socksState.CloseWriteSeen = true
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
}
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseWrite)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
s.log.Debugf(
"<gray>received close_write socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstreamWrite()
return &response, nil
case protocol.PacketTypeSOCKSRST:
socksState := s.getOrCreateSOCKSStateLocked(session, packet, now)
socksState.LastActivityAt = now
socksState.ResetSeen = true
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
}
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
s.log.Debugf(
"<gray>received rst socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstream()
socksState.release()
delete(session.SOCKSConnections, packet.SOCKSID)
return &response, nil
return responses, nil
case protocol.PacketTypePing:
s.log.Debugf(
@@ -354,7 +324,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
)
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypePong)
response.Payload = append([]byte(nil), packet.Payload...)
return &response, nil
return []protocol.Packet{response}, nil
case protocol.PacketTypeSOCKSConnectAck,
protocol.PacketTypeSOCKSConnectFail,
@@ -374,6 +344,8 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
default:
return nil, fmt.Errorf("unsupported packet type: %s", packet.Type)
}
return nil, nil
}
func (s *Server) getOrCreateSession(clientSessionKey string) *ClientSession {
@@ -407,6 +379,9 @@ func (s *Server) getOrCreateSession(clientSessionKey string) *ClientSession {
func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet protocol.Packet, now time.Time) *SOCKSState {
socksState := session.SOCKSConnections[packet.SOCKSID]
if socksState != nil {
if socksState.PendingInbound == nil {
socksState.PendingInbound = make(map[uint64]PendingInboundPacket)
}
return socksState
}
@@ -417,6 +392,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
Target: packet.Target,
LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
PendingInbound: make(map[uint64]PendingInboundPacket),
}
session.SOCKSConnections[packet.SOCKSID] = socksState
s.log.Debugf(
@@ -426,6 +402,120 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
return socksState
}
func (s *Server) applyOrderedPacketLocked(session *ClientSession, socksState *SOCKSState, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) {
socksState.LastActivityAt = now
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
}
switch packet.Type {
case protocol.PacketTypeSOCKSData:
socksState.ReceivedBytes += uint64(len(packet.Payload))
s.log.Debugf(
"<gray>write upstream socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> payload_bytes=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, len(packet.Payload), session.ClientSessionKey,
)
if err := socksState.writeUpstream(packet.Payload); err != nil {
s.log.Warnf(
"<yellow>write upstream failed socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey, err,
)
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
return []protocol.Packet{response}, nil
}
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSDataAck)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
response.FragmentID = packet.FragmentID
response.TotalFragments = packet.TotalFragments
response.Final = packet.Final
return []protocol.Packet{response}, nil
case protocol.PacketTypeSOCKSCloseRead:
socksState.CloseReadSeen = true
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseRead)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
s.log.Debugf(
"<gray>received close_read socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstreamRead()
return []protocol.Packet{response}, nil
case protocol.PacketTypeSOCKSCloseWrite:
socksState.CloseWriteSeen = true
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseWrite)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
s.log.Debugf(
"<gray>received close_write socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstreamWrite()
return []protocol.Packet{response}, nil
case protocol.PacketTypeSOCKSRST:
socksState.ResetSeen = true
response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
s.log.Debugf(
"<gray>received rst socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
)
_ = socksState.closeUpstream()
socksState.release()
delete(session.SOCKSConnections, packet.SOCKSID)
return []protocol.Packet{response}, nil
default:
return nil, nil
}
}
func (s *Server) duplicateResponsePacket(session *ClientSession, packet protocol.Packet) protocol.Packet {
responseType := packet.Type
if packet.Type == protocol.PacketTypeSOCKSData {
responseType = protocol.PacketTypeSOCKSDataAck
}
response := protocol.NewPacket(session.ClientSessionKey, responseType)
response.SOCKSID = packet.SOCKSID
response.Sequence = packet.Sequence
response.FragmentID = packet.FragmentID
response.TotalFragments = packet.TotalFragments
response.Final = packet.Final
return response
}
func (s *Server) expireReorderGapsLocked(session *ClientSession, now time.Time) []protocol.Packet {
timeout := time.Duration(s.cfg.ReorderTimeoutMS) * time.Millisecond
if timeout <= 0 {
return nil
}
responses := make([]protocol.Packet, 0)
for socksID, socksState := range session.SOCKSConnections {
if !socksState.hasExpiredInboundGapLocked(now, timeout) {
continue
}
s.log.Warnf(
"<yellow>expired inbound reorder gap client_session_key=<cyan>%s</cyan> socks_id=<cyan>%d</cyan></yellow>",
session.ClientSessionKey, socksID,
)
rst := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST)
rst.SOCKSID = socksID
rst.Sequence = socksState.LastSequenceSeen + 1
responses = append(responses, rst)
_ = socksState.closeUpstream()
socksState.release()
delete(session.SOCKSConnections, socksID)
}
return responses
}
func (s *Server) drainSessionOutboundLocked(session *ClientSession) []protocol.Packet {
packets := make([]protocol.Packet, 0)
if len(session.SOCKSConnections) == 0 {
@@ -523,6 +613,60 @@ func (s *SOCKSState) nextOutboundSequence() uint64 {
return s.OutboundSequence
}
func (s *SOCKSState) expectedInboundSequenceLocked() uint64 {
if s.NextInboundSequence == 0 {
return 1
}
return s.NextInboundSequence
}
func (s *SOCKSState) queueInboundPacketLocked(packet protocol.Packet, now time.Time, maxBuffered int) ([]protocol.Packet, bool, bool) {
expected := s.expectedInboundSequenceLocked()
if packet.Sequence < expected {
return nil, true, false
}
if _, exists := s.PendingInbound[packet.Sequence]; exists {
return nil, true, false
}
if len(s.PendingInbound) >= maxBuffered {
return nil, false, true
}
s.PendingInbound[packet.Sequence] = PendingInboundPacket{
Packet: packet,
QueuedAt: now,
}
if !s.ConnectAcked {
return nil, false, false
}
return s.drainReadyInboundLocked(), false, false
}
func (s *SOCKSState) drainReadyInboundLocked() []protocol.Packet {
expected := s.expectedInboundSequenceLocked()
ready := make([]protocol.Packet, 0)
for {
pending, ok := s.PendingInbound[expected]
if !ok {
break
}
ready = append(ready, pending.Packet)
delete(s.PendingInbound, expected)
expected++
}
s.NextInboundSequence = expected
return ready
}
func (s *SOCKSState) hasExpiredInboundGapLocked(now time.Time, timeout time.Duration) bool {
for _, pending := range s.PendingInbound {
if now.Sub(pending.QueuedAt) >= timeout {
clear(s.PendingInbound)
return true
}
}
return false
}
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = s.ID
@@ -572,6 +716,8 @@ func (s *SOCKSState) release() {
s.OutboundQueue = nil
s.QueuedBytes = 0
s.queueMu.Unlock()
clear(s.PendingInbound)
s.NextInboundSequence = 0
s.Target = nil
}