Add in-flight ACK tracking and retry handling

This commit is contained in:
Amin.MasterkinG
2026-04-20 20:04:53 +03:30
parent c4f4779ec9
commit 2baf5e8718
5 changed files with 191 additions and 3 deletions
+42 -3
View File
@@ -55,6 +55,7 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
default:
}
c.reclaimExpiredInFlight()
batch, selected := c.buildNextBatch()
if len(batch.Packets) == 0 {
time.Sleep(pollInterval)
@@ -68,6 +69,8 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
continue
}
c.markSelectedInFlight(selected)
body, err := protocol.EncryptBatch(batch, c.cfg.AESEncryptionKey)
if err != nil {
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
@@ -164,13 +167,41 @@ func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch,
}
func (c *Client) requeueSelected(selected []dequeuedPacket) {
grouped := make(map[*SOCKSConnection][]string)
for _, entry := range selected {
grouped[entry.socksConn] = append(grouped[entry.socksConn], entry.item.IdentityKey)
}
for socksConn, identityKeys := range grouped {
socksConn.RequeueInFlightByIdentity(identityKeys)
}
}
func (c *Client) markSelectedInFlight(selected []dequeuedPacket) {
grouped := make(map[*SOCKSConnection][]*SOCKSOutboundQueueItem)
for _, entry := range selected {
grouped[entry.socksConn] = append(grouped[entry.socksConn], entry.item)
}
for socksConn, items := range grouped {
socksConn.RequeueFront(items)
socksConn.MarkInFlight(items)
}
}
func (c *Client) reclaimExpiredInFlight() {
ackTimeout := time.Duration(c.cfg.AckTimeoutMS) * time.Millisecond
for _, socksConn := range c.socksConnections.Snapshot() {
requeued, dropped := socksConn.ReclaimExpiredInFlight(ackTimeout, c.cfg.MaxRetryCount)
if requeued > 0 || dropped > 0 {
c.log.Warnf(
"<yellow>socks_id=<cyan>%d</cyan> reclaimed inflight requeued=<cyan>%d</cyan> dropped=<cyan>%d</cyan></yellow>",
socksConn.ID, requeued, dropped,
)
if dropped > 0 {
socksConn.ConnectFailure = "max retry exceeded"
socksConn.CompleteConnect(fmt.Errorf("max retry exceeded"))
_ = socksConn.CloseLocal()
}
}
}
}
@@ -227,7 +258,7 @@ func (c *Client) applyResponseBatch(batch protocol.Batch) error {
func (c *Client) applyResponsePacket(packet protocol.Packet) error {
switch packet.Type {
case protocol.PacketTypePing, protocol.PacketTypePong, protocol.PacketTypeSOCKSDataAck:
case protocol.PacketTypePing, protocol.PacketTypePong:
return nil
}
@@ -238,6 +269,7 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error {
switch packet.Type {
case protocol.PacketTypeSOCKSConnectAck:
_ = socksConn.AckPacket(packet)
socksConn.ConnectAccepted = true
socksConn.LastActivityAt = time.Now()
socksConn.CompleteConnect(nil)
@@ -257,16 +289,23 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error {
if len(packet.Payload) > 0 {
message = string(packet.Payload)
}
_ = socksConn.AckPacket(packet)
socksConn.ConnectFailure = message
socksConn.CompleteConnect(fmt.Errorf("%s", message))
_ = socksConn.CloseLocal()
return nil
case protocol.PacketTypeSOCKSDataAck:
_ = socksConn.AckPacket(packet)
socksConn.LastActivityAt = time.Now()
return nil
case protocol.PacketTypeSOCKSData:
socksConn.LastActivityAt = time.Now()
return socksConn.WriteToLocal(packet.Payload)
case protocol.PacketTypeSOCKSCloseRead, protocol.PacketTypeSOCKSCloseWrite, protocol.PacketTypeSOCKSRST:
_ = socksConn.AckPacket(packet)
socksConn.LastActivityAt = time.Now()
return socksConn.CloseLocal()
+2
View File
@@ -42,6 +42,7 @@ type SOCKSConnection struct {
queueMu sync.Mutex
OutboundQueue []*SOCKSOutboundQueueItem
QueuedBytes int
InFlight map[string]*SOCKSOutboundQueueItem
}
func (s *SOCKSConnection) InitialPayloadHex() string {
@@ -74,6 +75,7 @@ func (s *SOCKSConnectionStore) New(clientSessionKey string, clientAddress string
LastActivityAt: now,
ClientAddress: clientAddress,
connectResultC: make(chan error, 1),
InFlight: make(map[string]*SOCKSOutboundQueueItem),
}
s.mu.Lock()
+123
View File
@@ -39,7 +39,9 @@ type SOCKSOutboundQueueItem struct {
IdentityKey string
Packet protocol.Packet
QueuedAt time.Time
SentAt time.Time
PayloadSize int
RetryCount int
}
func (s *SOCKSConnection) EnqueuePacket(packet protocol.Packet) error {
@@ -136,6 +138,127 @@ func (s *SOCKSConnection) RequeueFront(items []*SOCKSOutboundQueueItem) {
s.OutboundQueue = front
}
func (s *SOCKSConnection) MarkInFlight(items []*SOCKSOutboundQueueItem) {
if len(items) == 0 {
return
}
s.queueMu.Lock()
defer s.queueMu.Unlock()
for _, item := range items {
if item == nil {
continue
}
item.SentAt = time.Now()
s.InFlight[item.IdentityKey] = item
}
}
func (s *SOCKSConnection) AckPacket(packet protocol.Packet) bool {
identityKey := protocol.PacketIdentityKey(
packet.ClientSessionKey,
packet.SOCKSID,
ackTargetPacketType(packet.Type),
packet.Sequence,
packet.FragmentID,
)
s.queueMu.Lock()
defer s.queueMu.Unlock()
if _, ok := s.InFlight[identityKey]; ok {
delete(s.InFlight, identityKey)
return true
}
return false
}
func (s *SOCKSConnection) RequeueInFlightByIdentity(identityKeys []string) {
if len(identityKeys) == 0 {
return
}
s.queueMu.Lock()
defer s.queueMu.Unlock()
front := make([]*SOCKSOutboundQueueItem, 0, len(identityKeys)+len(s.OutboundQueue))
for _, identityKey := range identityKeys {
item, ok := s.InFlight[identityKey]
if !ok || item == nil {
continue
}
delete(s.InFlight, identityKey)
item.SentAt = time.Time{}
front = append(front, item)
s.QueuedBytes += item.PayloadSize
}
front = append(front, s.OutboundQueue...)
s.OutboundQueue = front
}
func (s *SOCKSConnection) ReclaimExpiredInFlight(ackTimeout time.Duration, maxRetryCount int) (requeued int, dropped int) {
now := time.Now()
s.queueMu.Lock()
defer s.queueMu.Unlock()
if len(s.InFlight) == 0 {
return 0, 0
}
front := make([]*SOCKSOutboundQueueItem, 0, len(s.InFlight)+len(s.OutboundQueue))
for identityKey, item := range s.InFlight {
if item == nil || item.SentAt.IsZero() || now.Sub(item.SentAt) < ackTimeout {
continue
}
delete(s.InFlight, identityKey)
if item.RetryCount >= maxRetryCount {
dropped++
continue
}
item.RetryCount++
item.SentAt = time.Time{}
front = append(front, item)
s.QueuedBytes += item.PayloadSize
requeued++
}
if len(front) > 0 {
front = append(front, s.OutboundQueue...)
s.OutboundQueue = front
}
return requeued, dropped
}
func ackTargetPacketType(packetType protocol.PacketType) protocol.PacketType {
switch packetType {
case protocol.PacketTypeSOCKSConnectAck,
protocol.PacketTypeSOCKSConnectFail,
protocol.PacketTypeSOCKSRuleSetDenied,
protocol.PacketTypeSOCKSNetworkUnreachable,
protocol.PacketTypeSOCKSHostUnreachable,
protocol.PacketTypeSOCKSConnectionRefused,
protocol.PacketTypeSOCKSTTLExpired,
protocol.PacketTypeSOCKSCommandUnsupported,
protocol.PacketTypeSOCKSAddressTypeUnsupported,
protocol.PacketTypeSOCKSAuthFailed,
protocol.PacketTypeSOCKSUpstreamUnavailable:
return protocol.PacketTypeSOCKSConnect
case protocol.PacketTypeSOCKSDataAck:
return protocol.PacketTypeSOCKSData
case protocol.PacketTypeSOCKSCloseRead:
return protocol.PacketTypeSOCKSCloseRead
case protocol.PacketTypeSOCKSCloseWrite:
return protocol.PacketTypeSOCKSCloseWrite
case protocol.PacketTypeSOCKSRST:
return protocol.PacketTypeSOCKSRST
default:
return packetType
}
}
func splitPayloadChunks(payload []byte, maxChunkSize int) [][]byte {
if len(payload) == 0 || maxChunkSize <= 0 {
return nil