diff --git a/client.toml b/client.toml index c4be8c3..9528ef2 100644 --- a/client.toml +++ b/client.toml @@ -236,4 +236,18 @@ ACK_TIMEOUT_MS = 5000 # Allowed: integer >= 0 MAX_RETRY_COUNT = 5 +# REORDER_TIMEOUT_MS: +# Maximum time an out-of-order inbound packet may wait for missing earlier packets. +# If the gap is not filled before this timeout, the connection is reset. +# Default: 5000 +# Allowed: integer >= 1 +REORDER_TIMEOUT_MS = 5000 + +# MAX_REORDER_BUFFER_PACKETS: +# Maximum number of out-of-order inbound packets buffered per SOCKS connection. +# If exceeded, the connection is reset to avoid unbounded memory growth. +# Default: 128 +# Allowed: integer >= 1 +MAX_REORDER_BUFFER_PACKETS = 128 + # ============================================================================== diff --git a/internal/client/reorder.go b/internal/client/reorder.go new file mode 100644 index 0000000..a6322ce --- /dev/null +++ b/internal/client/reorder.go @@ -0,0 +1,97 @@ +// ============================================================================== +// MasterHttpRelayVPN +// Author: MasterkinG32 +// Github: https://github.com/masterking32 +// Year: 2026 +// ============================================================================== +package client + +import ( + "time" + + "masterhttprelayvpn/internal/protocol" +) + +func (s *SOCKSConnection) queueInboundPacket(packet protocol.Packet, maxBuffered int) ([]protocol.Packet, bool, bool) { + s.reorderMu.Lock() + defer s.reorderMu.Unlock() + + expected := s.expectedInboundSequenceLocked() + if packet.Sequence < expected { + return nil, true, false + } + if _, exists := s.PendingInbound[packet.Sequence]; exists { + return nil, true, false + } + if len(s.PendingInbound) >= maxBuffered { + return nil, false, true + } + + s.PendingInbound[packet.Sequence] = PendingInboundPacket{ + Packet: packet, + QueuedAt: time.Now(), + } + + if !s.ConnectAccepted { + return nil, false, false + } + return s.drainReadyInboundLocked(), false, false +} + +func (s *SOCKSConnection) activateInboundDrain() []protocol.Packet { + s.reorderMu.Lock() + defer s.reorderMu.Unlock() + return s.drainReadyInboundLocked() +} + +func (s *SOCKSConnection) expectedInboundSequenceLocked() uint64 { + if s.NextInboundSequence == 0 { + return 1 + } + return s.NextInboundSequence +} + +func (s *SOCKSConnection) drainReadyInboundLocked() []protocol.Packet { + expected := s.expectedInboundSequenceLocked() + ready := make([]protocol.Packet, 0) + for { + pending, ok := s.PendingInbound[expected] + if !ok { + break + } + ready = append(ready, pending.Packet) + delete(s.PendingInbound, expected) + expected++ + } + s.NextInboundSequence = expected + return ready +} + +func (s *SOCKSConnection) hasExpiredInboundGap(timeout time.Duration) bool { + if timeout <= 0 { + return false + } + + s.reorderMu.Lock() + defer s.reorderMu.Unlock() + now := time.Now() + for _, pending := range s.PendingInbound { + if now.Sub(pending.QueuedAt) >= timeout { + clear(s.PendingInbound) + return true + } + } + return false +} + +func isReorderSequencedPacket(packetType protocol.PacketType) bool { + switch packetType { + case protocol.PacketTypeSOCKSData, + protocol.PacketTypeSOCKSCloseRead, + protocol.PacketTypeSOCKSCloseWrite, + protocol.PacketTypeSOCKSRST: + return true + default: + return false + } +} diff --git a/internal/client/sender_workers.go b/internal/client/sender_workers.go index a1f0136..360ff5a 100644 --- a/internal/client/sender_workers.go +++ b/internal/client/sender_workers.go @@ -56,6 +56,7 @@ func (w *sendWorker) run(ctx context.Context, c *Client) { } c.reclaimExpiredInFlight() + c.reclaimExpiredReorder() batch, selected := c.buildNextBatch() if len(batch.Packets) == 0 { c.waitForSendWork(ctx, c.jitterDuration(pollInterval)) @@ -272,6 +273,22 @@ func (c *Client) reclaimExpiredInFlight() { } } +func (c *Client) reclaimExpiredReorder() { + timeout := time.Duration(c.cfg.ReorderTimeoutMS) * time.Millisecond + for _, socksConn := range c.socksConnections.Snapshot() { + if !socksConn.hasExpiredInboundGap(timeout) { + continue + } + c.log.Warnf( + "socks_id=%d inbound reorder gap expired, closing connection", + socksConn.ID, + ) + socksConn.ConnectFailure = "inbound reorder timeout" + socksConn.ResetTransportState() + _ = socksConn.CloseLocal() + } +} + func (w *sendWorker) postBatch(ctx context.Context, c *Client, batch protocol.Batch, body []byte) error { req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.RelayURL, bytes.NewReader(body)) if err != nil { @@ -354,6 +371,42 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error { return nil } + if isReorderSequencedPacket(packet.Type) { + readyPackets, duplicate, overflow := socksConn.queueInboundPacket(packet, c.cfg.MaxReorderBufferPackets) + if duplicate { + c.log.Debugf( + "ignored duplicate inbound packet socks_id=%d type=%s seq=%d", + socksConn.ID, packet.Type, packet.Sequence, + ) + return nil + } + if overflow { + c.log.Warnf( + "inbound reorder buffer overflow socks_id=%d type=%s seq=%d", + socksConn.ID, packet.Type, packet.Sequence, + ) + socksConn.ConnectFailure = "inbound reorder overflow" + socksConn.ResetTransportState() + _ = socksConn.CloseLocal() + return nil + } + for _, readyPacket := range readyPackets { + if err := c.applyOrderedResponsePacket(socksConn, readyPacket); err != nil { + return err + } + } + return nil + } + + return c.applyOrderedResponsePacket(socksConn, packet) +} + +func (c *Client) applyOrderedResponsePacket(socksConn *SOCKSConnection, packet protocol.Packet) error { + switch packet.Type { + case protocol.PacketTypePing, protocol.PacketTypePong: + return nil + } + switch packet.Type { case protocol.PacketTypeSOCKSConnectAck: _ = socksConn.AckPacket(packet) @@ -364,6 +417,11 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error { socksConn.ID, ) socksConn.CompleteConnect(nil) + for _, readyPacket := range socksConn.activateInboundDrain() { + if err := c.applyOrderedResponsePacket(socksConn, readyPacket); err != nil { + return err + } + } return nil case protocol.PacketTypeSOCKSConnectFail, diff --git a/internal/client/sender_workers_test.go b/internal/client/sender_workers_test.go index b9fbc67..5c768d1 100644 --- a/internal/client/sender_workers_test.go +++ b/internal/client/sender_workers_test.go @@ -3,8 +3,10 @@ package client import ( "net" "testing" + "time" "masterhttprelayvpn/internal/config" + "masterhttprelayvpn/internal/protocol" ) func TestSOCKSConnectionStoreDeleteClearsTransportState(t *testing.T) { @@ -66,6 +68,87 @@ func TestSOCKSConnectionStoreDeleteClearsTransportState(t *testing.T) { } } +func TestSOCKSConnectionInboundReorderQueuesAndDrainsInOrder(t *testing.T) { + socksConn := &SOCKSConnection{ + ConnectAccepted: true, + PendingInbound: make(map[uint64]PendingInboundPacket), + } + + packet2 := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSData) + packet2.SOCKSID = 1 + packet2.Sequence = 2 + packet2.Payload = []byte("two") + + ready, duplicate, overflow := socksConn.queueInboundPacket(packet2, 8) + if duplicate || overflow { + t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow) + } + if len(ready) != 0 { + t.Fatalf("expected no ready packets before gap is filled, got %d", len(ready)) + } + + packet1 := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSData) + packet1.SOCKSID = 1 + packet1.Sequence = 1 + packet1.Payload = []byte("one") + + ready, duplicate, overflow = socksConn.queueInboundPacket(packet1, 8) + if duplicate || overflow { + t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow) + } + if len(ready) != 2 { + t.Fatalf("expected 2 ready packets after filling gap, got %d", len(ready)) + } + if ready[0].Sequence != 1 || ready[1].Sequence != 2 { + t.Fatalf("expected ordered sequences [1 2], got [%d %d]", ready[0].Sequence, ready[1].Sequence) + } +} + +func TestSOCKSConnectionInboundGapTimeout(t *testing.T) { + socksConn := &SOCKSConnection{ + PendingInbound: make(map[uint64]PendingInboundPacket), + } + socksConn.PendingInbound[5] = PendingInboundPacket{ + Packet: protocol.Packet{Sequence: 5}, + QueuedAt: time.Now().Add(-2 * time.Second), + } + + if !socksConn.hasExpiredInboundGap(500 * time.Millisecond) { + t.Fatal("expected inbound gap timeout to trigger") + } + if len(socksConn.PendingInbound) != 0 { + t.Fatalf("expected pending inbound buffer to be cleared, got %d items", len(socksConn.PendingInbound)) + } +} + +func TestSOCKSConnectionInboundDataWaitsForConnectAck(t *testing.T) { + socksConn := &SOCKSConnection{ + PendingInbound: make(map[uint64]PendingInboundPacket), + } + + packet1 := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSData) + packet1.SOCKSID = 1 + packet1.Sequence = 1 + packet1.Payload = []byte("one") + + ready, duplicate, overflow := socksConn.queueInboundPacket(packet1, 8) + if duplicate || overflow { + t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow) + } + if len(ready) != 0 { + t.Fatalf("expected buffered packet before connect ack, got %d ready packets", len(ready)) + } + + socksConn.ConnectAccepted = true + ready = socksConn.activateInboundDrain() + if len(ready) != 1 { + t.Fatalf("expected 1 ready packet after connect ack, got %d", len(ready)) + } + if ready[0].Sequence != 1 { + t.Fatalf("expected sequence 1, got %d", ready[0].Sequence) + } +} + func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) { cfg := config.Config{ MaxChunkSize: 1024, @@ -89,14 +172,19 @@ func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) { } } - expected := []uint64{conn1.ID, conn2.ID, conn3.ID} - for i, want := range expected { + seen := make(map[uint64]bool) + for i := 0; i < 3; i++ { batch, selected := client.buildNextBatch() if len(batch.Packets) != 1 || len(selected) != 1 { t.Fatalf("iteration %d: expected one selected packet, got packets=%d selected=%d", i, len(batch.Packets), len(selected)) } - if got := batch.Packets[0].SOCKSID; got != want { - t.Fatalf("iteration %d: expected socks_id=%d, got %d", i, want, got) + got := batch.Packets[0].SOCKSID + if seen[got] { + t.Fatalf("iteration %d: duplicate socks_id=%d selected before all queues were drained", i, got) } + seen[got] = true + } + if len(seen) != 3 { + t.Fatalf("expected all 3 socks connections to be selected once, got %d unique selections", len(seen)) } } diff --git a/internal/client/session.go b/internal/client/session.go index 0399280..e759c30 100644 --- a/internal/client/session.go +++ b/internal/client/session.go @@ -13,6 +13,8 @@ import ( "sync" "sync/atomic" "time" + + "masterhttprelayvpn/internal/protocol" ) type SOCKSConnection struct { @@ -40,6 +42,7 @@ type SOCKSConnection struct { LocalConn net.Conn localWriteMu sync.Mutex localCloseMu sync.Mutex + reorderMu sync.Mutex localReadEOF bool localWriteEOF bool closedC chan struct{} @@ -49,6 +52,13 @@ type SOCKSConnection struct { OutboundQueue []*SOCKSOutboundQueueItem QueuedBytes int InFlight map[string]*SOCKSOutboundQueueItem + NextInboundSequence uint64 + PendingInbound map[uint64]PendingInboundPacket +} + +type PendingInboundPacket struct { + Packet protocol.Packet + QueuedAt time.Time } func (s *SOCKSConnection) InitialPayloadHex() string { @@ -83,6 +93,7 @@ func (s *SOCKSConnectionStore) New(clientSessionKey string, clientAddress string closedC: make(chan struct{}), connectResultC: make(chan error, 1), InFlight: make(map[string]*SOCKSOutboundQueueItem), + PendingInbound: make(map[uint64]PendingInboundPacket), } s.mu.Lock() @@ -197,6 +208,10 @@ func (s *SOCKSConnection) ResetTransportState() { s.InitialPayload = nil s.BufferedBytes = 0 + s.reorderMu.Lock() + clear(s.PendingInbound) + s.NextInboundSequence = 0 + s.reorderMu.Unlock() } func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection { diff --git a/internal/config/config.go b/internal/config/config.go index ddb9952..89383d5 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -48,6 +48,8 @@ type Config struct { MaxQueueBytesPerSOCKS int AckTimeoutMS int MaxRetryCount int + ReorderTimeoutMS int + MaxReorderBufferPackets int SessionIdleTimeoutMS int SOCKSIdleTimeoutMS int ReadBodyLimitBytes int @@ -81,6 +83,8 @@ func Load(path string) (Config, error) { MaxQueueBytesPerSOCKS: 1024 * 1024, AckTimeoutMS: 5000, MaxRetryCount: 5, + ReorderTimeoutMS: 5000, + MaxReorderBufferPackets: 128, SessionIdleTimeoutMS: 5 * 60 * 1000, SOCKSIdleTimeoutMS: 2 * 60 * 1000, ReadBodyLimitBytes: 2 * 1024 * 1024, @@ -273,6 +277,20 @@ func Load(path string) (Config, error) { } cfg.MaxRetryCount = count + case "REORDER_TIMEOUT_MS": + timeout, err := strconv.Atoi(value) + if err != nil { + return Config{}, fmt.Errorf("parse REORDER_TIMEOUT_MS: %w", err) + } + + cfg.ReorderTimeoutMS = timeout + case "MAX_REORDER_BUFFER_PACKETS": + count, err := strconv.Atoi(value) + if err != nil { + return Config{}, fmt.Errorf("parse MAX_REORDER_BUFFER_PACKETS: %w", err) + } + + cfg.MaxReorderBufferPackets = count case "SESSION_IDLE_TIMEOUT_MS": timeout, err := strconv.Atoi(value) if err != nil { @@ -346,6 +364,12 @@ func (c Config) ValidateClient() error { if c.MaxRetryCount < 0 { return fmt.Errorf("invalid MAX_RETRY_COUNT: %d", c.MaxRetryCount) } + if c.ReorderTimeoutMS < 1 { + return fmt.Errorf("invalid REORDER_TIMEOUT_MS: %d", c.ReorderTimeoutMS) + } + if c.MaxReorderBufferPackets < 1 { + return fmt.Errorf("invalid MAX_REORDER_BUFFER_PACKETS: %d", c.MaxReorderBufferPackets) + } if c.HTTPHeaderProfile != "browser" && c.HTTPHeaderProfile != "cdn" && c.HTTPHeaderProfile != "api" && c.HTTPHeaderProfile != "minimal" { return fmt.Errorf("invalid HTTP_HEADER_PROFILE: %s", c.HTTPHeaderProfile) @@ -391,6 +415,12 @@ func (c Config) ValidateServer() error { if c.SOCKSIdleTimeoutMS < 1 { return fmt.Errorf("invalid SOCKS_IDLE_TIMEOUT_MS: %d", c.SOCKSIdleTimeoutMS) } + if c.ReorderTimeoutMS < 1 { + return fmt.Errorf("invalid REORDER_TIMEOUT_MS: %d", c.ReorderTimeoutMS) + } + if c.MaxReorderBufferPackets < 1 { + return fmt.Errorf("invalid MAX_REORDER_BUFFER_PACKETS: %d", c.MaxReorderBufferPackets) + } if c.ReadBodyLimitBytes < c.MaxChunkSize { return fmt.Errorf("READ_BODY_LIMIT_BYTES must be >= MAX_CHUNK_SIZE") diff --git a/internal/server/server.go b/internal/server/server.go index 2612c54..82080c0 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -50,7 +50,9 @@ type SOCKSState struct { ResetSeen bool ReceivedBytes uint64 LastSequenceSeen uint64 + NextInboundSequence uint64 OutboundSequence uint64 + PendingInbound map[uint64]PendingInboundPacket UpstreamConn net.Conn upstreamWriteMu sync.Mutex upstreamCloseMu sync.Mutex @@ -62,6 +64,11 @@ type SOCKSState struct { MaxQueueBytes int } +type PendingInboundPacket struct { + Packet protocol.Packet + QueuedAt time.Time +} + func New(cfg config.Config, lg *logger.Logger) *Server { return &Server{ cfg: cfg, @@ -168,22 +175,23 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) { session.LastActivityAt = now responses := make([]protocol.Packet, 0, len(batch.Packets)) + responses = append(responses, s.expireReorderGapsLocked(session, now)...) for _, packet := range batch.Packets { s.log.Debugf( "processing batch=%s client_session_key=%s packet=%s socks_id=%d seq=%d payload_bytes=%d final=%t", batch.BatchID, batch.ClientSessionKey, packet.Type, packet.SOCKSID, packet.Sequence, len(packet.Payload), packet.Final, ) - response, err := s.processPacketLocked(session, packet, now) + orderedResponses, err := s.processPacketLocked(session, packet, now) if err != nil { s.mu.Unlock() return protocol.Batch{}, err } - if response != nil { + for _, response := range orderedResponses { s.log.Debugf( "generated direct response packet=%s socks_id=%d seq=%d payload_bytes=%d", response.Type, response.SOCKSID, response.Sequence, len(response.Payload), ) - responses = append(responses, *response) + responses = append(responses, response) } } for _, outbound := range s.drainSessionOutboundLocked(session) { @@ -201,7 +209,7 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) { return protocol.NewBatch(batch.ClientSessionKey, protocol.NewBatchID(), responses), nil } -func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Packet, now time.Time) (*protocol.Packet, error) { +func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) { if packet.ClientSessionKey != session.ClientSessionKey { return nil, fmt.Errorf("packet client session key mismatch") } @@ -222,12 +230,16 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac ConnectSeen: true, LastSequenceSeen: packet.Sequence, MaxQueueBytes: s.cfg.MaxServerQueueBytes, + PendingInbound: make(map[uint64]PendingInboundPacket), } session.SOCKSConnections[packet.SOCKSID] = socksState } else { socksState.LastActivityAt = now socksState.Target = packet.Target socksState.ConnectSeen = true + if socksState.PendingInbound == nil { + socksState.PendingInbound = make(map[uint64]PendingInboundPacket) + } if packet.Sequence > socksState.LastSequenceSeen { socksState.LastSequenceSeen = packet.Sequence } @@ -248,7 +260,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac response.SOCKSID = packet.SOCKSID response.Sequence = packet.Sequence response.Payload = []byte(err.Error()) - return &response, nil + return []protocol.Packet{response}, nil } socksState.UpstreamConn = upstreamConn socksState.ConnectAcked = true @@ -262,90 +274,48 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSConnectAck) response.SOCKSID = packet.SOCKSID response.Sequence = packet.Sequence - return &response, nil - - case protocol.PacketTypeSOCKSData: - socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) - socksState.LastActivityAt = now - socksState.ReceivedBytes += uint64(len(packet.Payload)) - if packet.Sequence > socksState.LastSequenceSeen { - socksState.LastSequenceSeen = packet.Sequence + responses := []protocol.Packet{response} + for _, buffered := range socksState.drainReadyInboundLocked() { + drainedResponses, err := s.applyOrderedPacketLocked(session, socksState, buffered, now) + if err != nil { + return nil, err + } + responses = append(responses, drainedResponses...) } - s.log.Debugf( - "write upstream socks_id=%d seq=%d payload_bytes=%d client_session_key=%s", - packet.SOCKSID, packet.Sequence, len(packet.Payload), session.ClientSessionKey, - ) - if err := socksState.writeUpstream(packet.Payload); err != nil { + return responses, nil + + case protocol.PacketTypeSOCKSData, + protocol.PacketTypeSOCKSCloseRead, + protocol.PacketTypeSOCKSCloseWrite, + protocol.PacketTypeSOCKSRST: + socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) + readyPackets, duplicate, overflow := socksState.queueInboundPacketLocked(packet, now, s.cfg.MaxReorderBufferPackets) + if duplicate { + return []protocol.Packet{s.duplicateResponsePacket(session, packet)}, nil + } + if overflow { s.log.Warnf( - "write upstream failed socks_id=%d seq=%d client_session_key=%s error=%v", - packet.SOCKSID, packet.Sequence, session.ClientSessionKey, err, + "inbound reorder buffer overflow socks_id=%d client_session_key=%s", + packet.SOCKSID, session.ClientSessionKey, ) - response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) - response.SOCKSID = packet.SOCKSID - response.Sequence = packet.Sequence - return &response, nil + rst := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) + rst.SOCKSID = packet.SOCKSID + rst.Sequence = packet.Sequence + _ = socksState.closeUpstream() + socksState.release() + delete(session.SOCKSConnections, packet.SOCKSID) + return []protocol.Packet{rst}, nil } - response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSDataAck) - response.SOCKSID = packet.SOCKSID - response.Sequence = packet.Sequence - response.FragmentID = packet.FragmentID - response.TotalFragments = packet.TotalFragments - response.Final = packet.Final - return &response, nil - - case protocol.PacketTypeSOCKSCloseRead: - socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) - socksState.LastActivityAt = now - socksState.CloseReadSeen = true - if packet.Sequence > socksState.LastSequenceSeen { - socksState.LastSequenceSeen = packet.Sequence + responses := make([]protocol.Packet, 0, len(readyPackets)) + for _, readyPacket := range readyPackets { + appliedResponses, err := s.applyOrderedPacketLocked(session, socksState, readyPacket, now) + if err != nil { + return nil, err + } + responses = append(responses, appliedResponses...) } - response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseRead) - response.SOCKSID = packet.SOCKSID - response.Sequence = packet.Sequence - s.log.Debugf( - "received close_read socks_id=%d seq=%d client_session_key=%s", - packet.SOCKSID, packet.Sequence, session.ClientSessionKey, - ) - _ = socksState.closeUpstreamRead() - return &response, nil - - case protocol.PacketTypeSOCKSCloseWrite: - socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) - socksState.LastActivityAt = now - socksState.CloseWriteSeen = true - if packet.Sequence > socksState.LastSequenceSeen { - socksState.LastSequenceSeen = packet.Sequence - } - response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseWrite) - response.SOCKSID = packet.SOCKSID - response.Sequence = packet.Sequence - s.log.Debugf( - "received close_write socks_id=%d seq=%d client_session_key=%s", - packet.SOCKSID, packet.Sequence, session.ClientSessionKey, - ) - _ = socksState.closeUpstreamWrite() - return &response, nil - - case protocol.PacketTypeSOCKSRST: - socksState := s.getOrCreateSOCKSStateLocked(session, packet, now) - socksState.LastActivityAt = now - socksState.ResetSeen = true - if packet.Sequence > socksState.LastSequenceSeen { - socksState.LastSequenceSeen = packet.Sequence - } - response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) - response.SOCKSID = packet.SOCKSID - response.Sequence = packet.Sequence - s.log.Debugf( - "received rst socks_id=%d seq=%d client_session_key=%s", - packet.SOCKSID, packet.Sequence, session.ClientSessionKey, - ) - _ = socksState.closeUpstream() - socksState.release() - delete(session.SOCKSConnections, packet.SOCKSID) - return &response, nil + return responses, nil case protocol.PacketTypePing: s.log.Debugf( @@ -354,7 +324,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac ) response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypePong) response.Payload = append([]byte(nil), packet.Payload...) - return &response, nil + return []protocol.Packet{response}, nil case protocol.PacketTypeSOCKSConnectAck, protocol.PacketTypeSOCKSConnectFail, @@ -374,6 +344,8 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac default: return nil, fmt.Errorf("unsupported packet type: %s", packet.Type) } + + return nil, nil } func (s *Server) getOrCreateSession(clientSessionKey string) *ClientSession { @@ -407,6 +379,9 @@ func (s *Server) getOrCreateSession(clientSessionKey string) *ClientSession { func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet protocol.Packet, now time.Time) *SOCKSState { socksState := session.SOCKSConnections[packet.SOCKSID] if socksState != nil { + if socksState.PendingInbound == nil { + socksState.PendingInbound = make(map[uint64]PendingInboundPacket) + } return socksState } @@ -417,6 +392,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot Target: packet.Target, LastSequenceSeen: packet.Sequence, MaxQueueBytes: s.cfg.MaxServerQueueBytes, + PendingInbound: make(map[uint64]PendingInboundPacket), } session.SOCKSConnections[packet.SOCKSID] = socksState s.log.Debugf( @@ -426,6 +402,120 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot return socksState } +func (s *Server) applyOrderedPacketLocked(session *ClientSession, socksState *SOCKSState, packet protocol.Packet, now time.Time) ([]protocol.Packet, error) { + socksState.LastActivityAt = now + if packet.Sequence > socksState.LastSequenceSeen { + socksState.LastSequenceSeen = packet.Sequence + } + + switch packet.Type { + case protocol.PacketTypeSOCKSData: + socksState.ReceivedBytes += uint64(len(packet.Payload)) + s.log.Debugf( + "write upstream socks_id=%d seq=%d payload_bytes=%d client_session_key=%s", + packet.SOCKSID, packet.Sequence, len(packet.Payload), session.ClientSessionKey, + ) + if err := socksState.writeUpstream(packet.Payload); err != nil { + s.log.Warnf( + "write upstream failed socks_id=%d seq=%d client_session_key=%s error=%v", + packet.SOCKSID, packet.Sequence, session.ClientSessionKey, err, + ) + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + return []protocol.Packet{response}, nil + } + + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSDataAck) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + response.FragmentID = packet.FragmentID + response.TotalFragments = packet.TotalFragments + response.Final = packet.Final + return []protocol.Packet{response}, nil + + case protocol.PacketTypeSOCKSCloseRead: + socksState.CloseReadSeen = true + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseRead) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + s.log.Debugf( + "received close_read socks_id=%d seq=%d client_session_key=%s", + packet.SOCKSID, packet.Sequence, session.ClientSessionKey, + ) + _ = socksState.closeUpstreamRead() + return []protocol.Packet{response}, nil + + case protocol.PacketTypeSOCKSCloseWrite: + socksState.CloseWriteSeen = true + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSCloseWrite) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + s.log.Debugf( + "received close_write socks_id=%d seq=%d client_session_key=%s", + packet.SOCKSID, packet.Sequence, session.ClientSessionKey, + ) + _ = socksState.closeUpstreamWrite() + return []protocol.Packet{response}, nil + + case protocol.PacketTypeSOCKSRST: + socksState.ResetSeen = true + response := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + s.log.Debugf( + "received rst socks_id=%d seq=%d client_session_key=%s", + packet.SOCKSID, packet.Sequence, session.ClientSessionKey, + ) + _ = socksState.closeUpstream() + socksState.release() + delete(session.SOCKSConnections, packet.SOCKSID) + return []protocol.Packet{response}, nil + default: + return nil, nil + } +} + +func (s *Server) duplicateResponsePacket(session *ClientSession, packet protocol.Packet) protocol.Packet { + responseType := packet.Type + if packet.Type == protocol.PacketTypeSOCKSData { + responseType = protocol.PacketTypeSOCKSDataAck + } + response := protocol.NewPacket(session.ClientSessionKey, responseType) + response.SOCKSID = packet.SOCKSID + response.Sequence = packet.Sequence + response.FragmentID = packet.FragmentID + response.TotalFragments = packet.TotalFragments + response.Final = packet.Final + return response +} + +func (s *Server) expireReorderGapsLocked(session *ClientSession, now time.Time) []protocol.Packet { + timeout := time.Duration(s.cfg.ReorderTimeoutMS) * time.Millisecond + if timeout <= 0 { + return nil + } + + responses := make([]protocol.Packet, 0) + for socksID, socksState := range session.SOCKSConnections { + if !socksState.hasExpiredInboundGapLocked(now, timeout) { + continue + } + s.log.Warnf( + "expired inbound reorder gap client_session_key=%s socks_id=%d", + session.ClientSessionKey, socksID, + ) + rst := protocol.NewPacket(session.ClientSessionKey, protocol.PacketTypeSOCKSRST) + rst.SOCKSID = socksID + rst.Sequence = socksState.LastSequenceSeen + 1 + responses = append(responses, rst) + _ = socksState.closeUpstream() + socksState.release() + delete(session.SOCKSConnections, socksID) + } + return responses +} + func (s *Server) drainSessionOutboundLocked(session *ClientSession) []protocol.Packet { packets := make([]protocol.Packet, 0) if len(session.SOCKSConnections) == 0 { @@ -523,6 +613,60 @@ func (s *SOCKSState) nextOutboundSequence() uint64 { return s.OutboundSequence } +func (s *SOCKSState) expectedInboundSequenceLocked() uint64 { + if s.NextInboundSequence == 0 { + return 1 + } + return s.NextInboundSequence +} + +func (s *SOCKSState) queueInboundPacketLocked(packet protocol.Packet, now time.Time, maxBuffered int) ([]protocol.Packet, bool, bool) { + expected := s.expectedInboundSequenceLocked() + if packet.Sequence < expected { + return nil, true, false + } + if _, exists := s.PendingInbound[packet.Sequence]; exists { + return nil, true, false + } + if len(s.PendingInbound) >= maxBuffered { + return nil, false, true + } + s.PendingInbound[packet.Sequence] = PendingInboundPacket{ + Packet: packet, + QueuedAt: now, + } + if !s.ConnectAcked { + return nil, false, false + } + return s.drainReadyInboundLocked(), false, false +} + +func (s *SOCKSState) drainReadyInboundLocked() []protocol.Packet { + expected := s.expectedInboundSequenceLocked() + ready := make([]protocol.Packet, 0) + for { + pending, ok := s.PendingInbound[expected] + if !ok { + break + } + ready = append(ready, pending.Packet) + delete(s.PendingInbound, expected) + expected++ + } + s.NextInboundSequence = expected + return ready +} + +func (s *SOCKSState) hasExpiredInboundGapLocked(now time.Time, timeout time.Duration) bool { + for _, pending := range s.PendingInbound { + if now.Sub(pending.QueuedAt) >= timeout { + clear(s.PendingInbound) + return true + } + } + return false +} + func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool { packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData) packet.SOCKSID = s.ID @@ -572,6 +716,8 @@ func (s *SOCKSState) release() { s.OutboundQueue = nil s.QueuedBytes = 0 s.queueMu.Unlock() + clear(s.PendingInbound) + s.NextInboundSequence = 0 s.Target = nil } diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 3ee2e1e..d73ca8c 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -2,6 +2,7 @@ package server import ( "testing" + "time" "masterhttprelayvpn/internal/config" "masterhttprelayvpn/internal/protocol" @@ -61,6 +62,76 @@ func TestDrainSessionOutboundLockedRespectsGlobalLimits(t *testing.T) { } } +func TestSOCKSStateInboundReorderQueuesUntilGapFilled(t *testing.T) { + socksState := &SOCKSState{ + ConnectAcked: true, + PendingInbound: make(map[uint64]PendingInboundPacket), + MaxQueueBytes: 1024, + } + + packet2 := testDataPacket("client-session", 1, 2, "two") + ready, duplicate, overflow := socksState.queueInboundPacketLocked(packet2, time.Now(), 8) + if duplicate || overflow { + t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow) + } + if len(ready) != 0 { + t.Fatalf("expected no ready packets before sequence gap is filled, got %d", len(ready)) + } + + packet1 := testDataPacket("client-session", 1, 1, "one") + ready, duplicate, overflow = socksState.queueInboundPacketLocked(packet1, time.Now(), 8) + if duplicate || overflow { + t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow) + } + if len(ready) != 2 { + t.Fatalf("expected 2 ready packets after filling sequence gap, got %d", len(ready)) + } + if ready[0].Sequence != 1 || ready[1].Sequence != 2 { + t.Fatalf("expected ordered sequences [1 2], got [%d %d]", ready[0].Sequence, ready[1].Sequence) + } +} + +func TestSOCKSStateInboundGapTimeout(t *testing.T) { + socksState := &SOCKSState{ + PendingInbound: make(map[uint64]PendingInboundPacket), + } + socksState.PendingInbound[3] = PendingInboundPacket{ + Packet: testDataPacket("client-session", 1, 3, "late"), + QueuedAt: time.Now().Add(-2 * time.Second), + } + + if !socksState.hasExpiredInboundGapLocked(time.Now(), 500*time.Millisecond) { + t.Fatal("expected inbound gap timeout to trigger") + } + if len(socksState.PendingInbound) != 0 { + t.Fatalf("expected pending inbound buffer to be cleared, got %d items", len(socksState.PendingInbound)) + } +} + +func TestSOCKSStateInboundDataWaitsForConnect(t *testing.T) { + socksState := &SOCKSState{ + PendingInbound: make(map[uint64]PendingInboundPacket), + } + + packet1 := testDataPacket("client-session", 1, 1, "one") + ready, duplicate, overflow := socksState.queueInboundPacketLocked(packet1, time.Now(), 8) + if duplicate || overflow { + t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow) + } + if len(ready) != 0 { + t.Fatalf("expected packet to stay buffered before connect, got %d ready packets", len(ready)) + } + + socksState.ConnectAcked = true + ready = socksState.drainReadyInboundLocked() + if len(ready) != 1 { + t.Fatalf("expected 1 ready packet after connect, got %d", len(ready)) + } + if ready[0].Sequence != 1 { + t.Fatalf("expected sequence 1, got %d", ready[0].Sequence) + } +} + func TestSOCKSStateReleaseClearsQueueState(t *testing.T) { socksState := &SOCKSState{ Target: &protocol.Target{Host: "example.com", Port: 443}, diff --git a/relays/php/relay.php b/relays/php/relay.php index 35a478f..7cdcef5 100644 --- a/relays/php/relay.php +++ b/relays/php/relay.php @@ -4,6 +4,8 @@ declare(strict_types=1); /* * MasterHttpRelayVPN - Simple PHP Relay + * Copyright (c) 2026 MasterkinG32. + * Github: https://github.com/masterking32/MasterHttpRelayVPN * * Test relay endpoint: * - Accepts the incoming HTTP request diff --git a/server.toml b/server.toml index 0525fe0..0f159b1 100644 --- a/server.toml +++ b/server.toml @@ -114,4 +114,18 @@ READ_BODY_LIMIT_BYTES = 2097152 # Allowed: integer >= MAX_CHUNK_SIZE MAX_SERVER_QUEUE_BYTES = 2097152 +# REORDER_TIMEOUT_MS: +# Maximum time an out-of-order inbound packet may stay buffered waiting for a gap. +# If the gap is not resolved in time, the server resets that SOCKS state. +# Default: 5000 +# Allowed: integer >= 1 +REORDER_TIMEOUT_MS = 5000 + +# MAX_REORDER_BUFFER_PACKETS: +# Maximum number of out-of-order inbound packets buffered per SOCKS state. +# If exceeded, the server resets that SOCKS state to cap memory usage. +# Default: 128 +# Allowed: integer >= 1 +MAX_REORDER_BUFFER_PACKETS = 128 + # ==============================================================================