Remove dead initial-payload path from SOCKS connect and align client relay flow

This commit is contained in:
Amin.MasterkinG
2026-04-21 14:04:01 +03:30
parent 22f13fb234
commit 8a50614510
4 changed files with 3 additions and 21 deletions
+1 -5
View File
@@ -45,8 +45,7 @@ func TestSOCKSConnectionStoreDeleteClearsTransportState(t *testing.T) {
socksConn := store.New("client-session", "127.0.0.1:1000", chunkPolicy) socksConn := store.New("client-session", "127.0.0.1:1000", chunkPolicy)
socksConn.LocalConn = localConn socksConn.LocalConn = localConn
socksConn.InitialPayload = []byte("initial-payload") socksConn.BufferedBytes = len("initial-payload")
socksConn.BufferedBytes = len(socksConn.InitialPayload)
if err := socksConn.EnqueuePacket(socksConn.BuildSOCKSDataPacket([]byte("hello"), false)); err != nil { if err := socksConn.EnqueuePacket(socksConn.BuildSOCKSDataPacket([]byte("hello"), false)); err != nil {
t.Fatalf("enqueue first packet: %v", err) t.Fatalf("enqueue first packet: %v", err)
@@ -75,9 +74,6 @@ func TestSOCKSConnectionStoreDeleteClearsTransportState(t *testing.T) {
if len(socksConn.InFlight) != 0 { if len(socksConn.InFlight) != 0 {
t.Fatalf("expected empty inflight map, got %d items", len(socksConn.InFlight)) t.Fatalf("expected empty inflight map, got %d items", len(socksConn.InFlight))
} }
if socksConn.InitialPayload != nil {
t.Fatal("expected initial payload to be cleared")
}
if socksConn.BufferedBytes != 0 { if socksConn.BufferedBytes != 0 {
t.Fatalf("expected buffered bytes to be reset, got %d", socksConn.BufferedBytes) t.Fatalf("expected buffered bytes to be reset, got %d", socksConn.BufferedBytes)
} }
-10
View File
@@ -8,7 +8,6 @@ package client
import ( import (
"context" "context"
"encoding/hex"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -27,7 +26,6 @@ type SOCKSConnection struct {
TargetHost string TargetHost string
TargetPort uint16 TargetPort uint16
TargetAddressType byte TargetAddressType byte
InitialPayload []byte
BufferedBytes int BufferedBytes int
NextSequence uint64 NextSequence uint64
SOCKSAuthMethod byte SOCKSAuthMethod byte
@@ -61,13 +59,6 @@ type PendingInboundPacket struct {
QueuedAt time.Time QueuedAt time.Time
} }
func (s *SOCKSConnection) InitialPayloadHex() string {
if len(s.InitialPayload) == 0 {
return ""
}
return hex.EncodeToString(s.InitialPayload)
}
type SOCKSConnectionStore struct { type SOCKSConnectionStore struct {
nextID atomic.Uint64 nextID atomic.Uint64
mu sync.RWMutex mu sync.RWMutex
@@ -206,7 +197,6 @@ func (s *SOCKSConnection) ResetTransportState() {
clear(s.InFlight) clear(s.InFlight)
s.queueMu.Unlock() s.queueMu.Unlock()
s.InitialPayload = nil
s.BufferedBytes = 0 s.BufferedBytes = 0
s.reorderMu.Lock() s.reorderMu.Lock()
clear(s.PendingInbound) clear(s.PendingInbound)
+2 -3
View File
@@ -112,7 +112,7 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn, socksConn *SOC
return err return err
} }
return c.captureInitialPayload(ctx, conn, socksConn) return c.relayLocalPayload(ctx, conn, socksConn)
} }
func (c *Client) negotiateAuth(conn net.Conn, socksConn *SOCKSConnection) (byte, error) { func (c *Client) negotiateAuth(conn net.Conn, socksConn *SOCKSConnection) (byte, error) {
@@ -258,7 +258,7 @@ func writeSocksReply(conn net.Conn, reply byte) error {
return err return err
} }
func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socksConn *SOCKSConnection) error { func (c *Client) relayLocalPayload(ctx context.Context, conn net.Conn, socksConn *SOCKSConnection) error {
peekTimeout := 2 * time.Second peekTimeout := 2 * time.Second
idleTimeout := 2 * time.Second idleTimeout := 2 * time.Second
buf := make([]byte, 32*1024) buf := make([]byte, 32*1024)
@@ -269,7 +269,6 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
n, err := conn.Read(buf) n, err := conn.Read(buf)
if err == nil && n > 0 { if err == nil && n > 0 {
socksConn.InitialPayload = append([]byte(nil), buf[:n]...)
socksConn.BufferedBytes += n socksConn.BufferedBytes += n
socksConn.LastActivityAt = time.Now() socksConn.LastActivityAt = time.Now()
enqueued, enqueueErr := socksConn.EnqueuePayloadChunks(buf[:n], false) enqueued, enqueueErr := socksConn.EnqueuePayloadChunks(buf[:n], false)
-3
View File
@@ -21,9 +21,6 @@ func (s *SOCKSConnection) BuildSOCKSConnectPacket() protocol.Packet {
Port: s.TargetPort, Port: s.TargetPort,
AddressType: s.TargetAddressType, AddressType: s.TargetAddressType,
} }
if len(s.InitialPayload) > 0 {
packet.Payload = append([]byte(nil), s.InitialPayload...)
}
return packet return packet
} }