Deduplicate packet reorder logic and consolidate client SOCKS session handling

This commit is contained in:
Amin.MasterkinG
2026-04-21 15:08:10 +03:30
parent 8a50614510
commit bf5c0ef06e
9 changed files with 497 additions and 573 deletions
+9 -59
View File
@@ -55,7 +55,7 @@ type SOCKSState struct {
LastSequenceSeen uint64
NextInboundSequence uint64
OutboundSequence uint64
PendingInbound map[uint64][]PendingInboundPacket
PendingInbound map[uint64][]protocol.PendingPacket
UpstreamConn net.Conn
activityMu sync.RWMutex
upstreamStateMu sync.RWMutex
@@ -68,11 +68,6 @@ type SOCKSState struct {
MaxQueueBytes int
}
type PendingInboundPacket struct {
Packet protocol.Packet
QueuedAt time.Time
}
func New(cfg config.Config, lg *logger.Logger) *Server {
return &Server{
cfg: cfg,
@@ -235,7 +230,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
ConnectSeen: true,
LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
PendingInbound: make(map[uint64][]PendingInboundPacket),
PendingInbound: make(map[uint64][]protocol.PendingPacket),
}
session.SOCKSConnections[packet.SOCKSID] = socksState
} else {
@@ -243,7 +238,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
socksState.Target = packet.Target
socksState.ConnectSeen = true
if socksState.PendingInbound == nil {
socksState.PendingInbound = make(map[uint64][]PendingInboundPacket)
socksState.PendingInbound = make(map[uint64][]protocol.PendingPacket)
}
if packet.Sequence > socksState.LastSequenceSeen {
socksState.LastSequenceSeen = packet.Sequence
@@ -387,7 +382,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
socksState := session.SOCKSConnections[packet.SOCKSID]
if socksState != nil {
if socksState.PendingInbound == nil {
socksState.PendingInbound = make(map[uint64][]PendingInboundPacket)
socksState.PendingInbound = make(map[uint64][]protocol.PendingPacket)
}
return socksState
}
@@ -399,7 +394,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
Target: packet.Target,
LastSequenceSeen: packet.Sequence,
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
PendingInbound: make(map[uint64][]PendingInboundPacket),
PendingInbound: make(map[uint64][]protocol.PendingPacket),
}
session.SOCKSConnections[packet.SOCKSID] = socksState
s.log.Debugf(
@@ -633,13 +628,13 @@ func (s *SOCKSState) queueInboundPacketLocked(packet protocol.Packet, now time.T
return nil, true, false
}
pendingForSequence := s.PendingInbound[packet.Sequence]
if containsPendingInboundPacketLocked(pendingForSequence, packet) {
if protocol.ContainsPendingPacket(pendingForSequence, packet) {
return nil, true, false
}
if bufferedInboundPacketCountLocked(s.PendingInbound) >= maxBuffered {
if protocol.BufferedPendingPacketCount(s.PendingInbound) >= maxBuffered {
return nil, false, true
}
s.PendingInbound[packet.Sequence] = append(s.PendingInbound[packet.Sequence], PendingInboundPacket{
s.PendingInbound[packet.Sequence] = append(s.PendingInbound[packet.Sequence], protocol.PendingPacket{
Packet: packet,
QueuedAt: now,
})
@@ -657,7 +652,7 @@ func (s *SOCKSState) drainReadyInboundLocked() []protocol.Packet {
if !ok || len(pendingPackets) == 0 {
break
}
sortPendingInboundPacketsLocked(pendingPackets)
protocol.SortPendingPackets(pendingPackets)
for _, pending := range pendingPackets {
ready = append(ready, pending.Packet)
}
@@ -680,51 +675,6 @@ func (s *SOCKSState) hasExpiredInboundGapLocked(now time.Time, timeout time.Dura
return false
}
func containsPendingInboundPacketLocked(pendingPackets []PendingInboundPacket, packet protocol.Packet) bool {
for _, pending := range pendingPackets {
if pending.Packet.Type == packet.Type &&
pending.Packet.FragmentID == packet.FragmentID &&
pending.Packet.TotalFragments == packet.TotalFragments {
return true
}
}
return false
}
func bufferedInboundPacketCountLocked(pending map[uint64][]PendingInboundPacket) int {
total := 0
for _, pendingPackets := range pending {
total += len(pendingPackets)
}
return total
}
func sortPendingInboundPacketsLocked(pendingPackets []PendingInboundPacket) {
for i := 1; i < len(pendingPackets); i++ {
current := pendingPackets[i]
j := i - 1
for ; j >= 0 && inboundPacketSortOrderLocked(current.Packet.Type) < inboundPacketSortOrderLocked(pendingPackets[j].Packet.Type); j-- {
pendingPackets[j+1] = pendingPackets[j]
}
pendingPackets[j+1] = current
}
}
func inboundPacketSortOrderLocked(packetType protocol.PacketType) int {
switch packetType {
case protocol.PacketTypeSOCKSData:
return 0
case protocol.PacketTypeSOCKSCloseRead:
return 1
case protocol.PacketTypeSOCKSCloseWrite:
return 2
case protocol.PacketTypeSOCKSRST:
return 3
default:
return 4
}
}
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool {
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
packet.SOCKSID = s.ID
+5 -5
View File
@@ -69,7 +69,7 @@ func TestDrainSessionOutboundLockedRespectsGlobalLimits(t *testing.T) {
func TestSOCKSStateInboundReorderQueuesUntilGapFilled(t *testing.T) {
socksState := &SOCKSState{
ConnectAcked: true,
PendingInbound: make(map[uint64][]PendingInboundPacket),
PendingInbound: make(map[uint64][]protocol.PendingPacket),
MaxQueueBytes: 1024,
}
@@ -97,9 +97,9 @@ func TestSOCKSStateInboundReorderQueuesUntilGapFilled(t *testing.T) {
func TestSOCKSStateInboundGapTimeout(t *testing.T) {
socksState := &SOCKSState{
PendingInbound: make(map[uint64][]PendingInboundPacket),
PendingInbound: make(map[uint64][]protocol.PendingPacket),
}
socksState.PendingInbound[3] = []PendingInboundPacket{{
socksState.PendingInbound[3] = []protocol.PendingPacket{{
Packet: testDataPacket("client-session", 1, 3, "late"),
QueuedAt: time.Now().Add(-2 * time.Second),
}}
@@ -114,7 +114,7 @@ func TestSOCKSStateInboundGapTimeout(t *testing.T) {
func TestSOCKSStateInboundDataWaitsForConnect(t *testing.T) {
socksState := &SOCKSState{
PendingInbound: make(map[uint64][]PendingInboundPacket),
PendingInbound: make(map[uint64][]protocol.PendingPacket),
}
packet1 := testDataPacket("client-session", 1, 1, "one")
@@ -139,7 +139,7 @@ func TestSOCKSStateInboundDataWaitsForConnect(t *testing.T) {
func TestSOCKSStateInboundReorderAllowsMultiplePacketTypesPerSequence(t *testing.T) {
socksState := &SOCKSState{
ConnectAcked: true,
PendingInbound: make(map[uint64][]PendingInboundPacket),
PendingInbound: make(map[uint64][]protocol.PendingPacket),
}
closeWrite := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSCloseWrite)