mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-18 06:34:40 +03:00
Changing ping system.
This commit is contained in:
@@ -32,9 +32,13 @@ type Client struct {
|
||||
conns map[net.Conn]struct{}
|
||||
workCh chan struct{}
|
||||
|
||||
lastPollUnixMS atomic.Int64
|
||||
activeBatches atomic.Int64
|
||||
batchCursor atomic.Uint64
|
||||
lastMeaningfulActivityUnixMS atomic.Int64
|
||||
lastPingMeaningfulSnapshotMS atomic.Int64
|
||||
nextPingDueUnixMS atomic.Int64
|
||||
activeBatches atomic.Int64
|
||||
pingInFlight atomic.Int64
|
||||
idlePongStreak atomic.Int64
|
||||
batchCursor atomic.Uint64
|
||||
}
|
||||
|
||||
func New(cfg config.Config, lg *logger.Logger) *Client {
|
||||
@@ -131,6 +135,57 @@ func (c *Client) signalSendWork() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) noteMeaningfulActivity(now time.Time) {
|
||||
c.lastMeaningfulActivityUnixMS.Store(now.UnixMilli())
|
||||
c.idlePongStreak.Store(0)
|
||||
c.nextPingDueUnixMS.Store(now.Add(time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond).UnixMilli())
|
||||
}
|
||||
|
||||
func (c *Client) tryBeginPing(now time.Time) bool {
|
||||
if !c.pingInFlight.CompareAndSwap(0, 1) {
|
||||
return false
|
||||
}
|
||||
|
||||
c.lastPingMeaningfulSnapshotMS.Store(c.lastMeaningfulActivityUnixMS.Load())
|
||||
return true
|
||||
}
|
||||
|
||||
func (c *Client) completePingWithPong() {
|
||||
c.pingInFlight.Store(0)
|
||||
now := time.Now()
|
||||
|
||||
if c.lastMeaningfulActivityUnixMS.Load() == c.lastPingMeaningfulSnapshotMS.Load() {
|
||||
lastMeaningfulAt := time.UnixMilli(c.lastMeaningfulActivityUnixMS.Load())
|
||||
idleFor := now.Sub(lastMeaningfulAt)
|
||||
warmThreshold := time.Duration(c.cfg.PingWarmThresholdMS) * time.Millisecond
|
||||
if idleFor < warmThreshold {
|
||||
c.idlePongStreak.Store(0)
|
||||
c.nextPingDueUnixMS.Store(now.Add(time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond).UnixMilli())
|
||||
return
|
||||
}
|
||||
|
||||
streak := c.idlePongStreak.Add(1)
|
||||
c.nextPingDueUnixMS.Store(now.Add(c.idleIntervalForStreak(streak)).UnixMilli())
|
||||
return
|
||||
}
|
||||
|
||||
c.idlePongStreak.Store(0)
|
||||
c.nextPingDueUnixMS.Store(now.Add(time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond).UnixMilli())
|
||||
}
|
||||
|
||||
func (c *Client) failPing() {
|
||||
c.pingInFlight.Store(0)
|
||||
c.nextPingDueUnixMS.Store(time.Now().Add(time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond).UnixMilli())
|
||||
}
|
||||
|
||||
func (c *Client) idleIntervalForStreak(streak int64) time.Duration {
|
||||
interval := c.cfg.PingBackoffBaseMS + int(streak)*c.cfg.PingBackoffStepMS
|
||||
if interval > c.cfg.PingMaxIntervalMS {
|
||||
interval = c.cfg.PingMaxIntervalMS
|
||||
}
|
||||
return time.Duration(interval) * time.Millisecond
|
||||
}
|
||||
|
||||
func generateClientSessionKey() string {
|
||||
now := time.Now().UTC().Format("20060102T150405.000000000Z")
|
||||
random := make([]byte, 16)
|
||||
|
||||
+61
-10
@@ -20,17 +20,18 @@ func (s *SOCKSConnection) queueInboundPacket(packet protocol.Packet, maxBuffered
|
||||
if packet.Sequence < expected {
|
||||
return nil, true, false
|
||||
}
|
||||
if _, exists := s.PendingInbound[packet.Sequence]; exists {
|
||||
pendingForSequence := s.PendingInbound[packet.Sequence]
|
||||
if containsPendingInboundPacket(pendingForSequence, packet) {
|
||||
return nil, true, false
|
||||
}
|
||||
if len(s.PendingInbound) >= maxBuffered {
|
||||
if bufferedInboundPacketCount(s.PendingInbound) >= maxBuffered {
|
||||
return nil, false, true
|
||||
}
|
||||
|
||||
s.PendingInbound[packet.Sequence] = PendingInboundPacket{
|
||||
s.PendingInbound[packet.Sequence] = append(s.PendingInbound[packet.Sequence], PendingInboundPacket{
|
||||
Packet: packet,
|
||||
QueuedAt: time.Now(),
|
||||
}
|
||||
})
|
||||
|
||||
if !s.ConnectAccepted {
|
||||
return nil, false, false
|
||||
@@ -55,11 +56,14 @@ func (s *SOCKSConnection) drainReadyInboundLocked() []protocol.Packet {
|
||||
expected := s.expectedInboundSequenceLocked()
|
||||
ready := make([]protocol.Packet, 0)
|
||||
for {
|
||||
pending, ok := s.PendingInbound[expected]
|
||||
if !ok {
|
||||
pendingPackets, ok := s.PendingInbound[expected]
|
||||
if !ok || len(pendingPackets) == 0 {
|
||||
break
|
||||
}
|
||||
ready = append(ready, pending.Packet)
|
||||
sortPendingInboundPackets(pendingPackets)
|
||||
for _, pending := range pendingPackets {
|
||||
ready = append(ready, pending.Packet)
|
||||
}
|
||||
delete(s.PendingInbound, expected)
|
||||
expected++
|
||||
}
|
||||
@@ -75,15 +79,62 @@ func (s *SOCKSConnection) hasExpiredInboundGap(timeout time.Duration) bool {
|
||||
s.reorderMu.Lock()
|
||||
defer s.reorderMu.Unlock()
|
||||
now := time.Now()
|
||||
for _, pending := range s.PendingInbound {
|
||||
if now.Sub(pending.QueuedAt) >= timeout {
|
||||
clear(s.PendingInbound)
|
||||
for _, pendingPackets := range s.PendingInbound {
|
||||
for _, pending := range pendingPackets {
|
||||
if now.Sub(pending.QueuedAt) >= timeout {
|
||||
clear(s.PendingInbound)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsPendingInboundPacket(pendingPackets []PendingInboundPacket, packet protocol.Packet) bool {
|
||||
for _, pending := range pendingPackets {
|
||||
if pending.Packet.Type == packet.Type &&
|
||||
pending.Packet.FragmentID == packet.FragmentID &&
|
||||
pending.Packet.TotalFragments == packet.TotalFragments {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bufferedInboundPacketCount(pending map[uint64][]PendingInboundPacket) int {
|
||||
total := 0
|
||||
for _, pendingPackets := range pending {
|
||||
total += len(pendingPackets)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func sortPendingInboundPackets(pendingPackets []PendingInboundPacket) {
|
||||
for i := 1; i < len(pendingPackets); i++ {
|
||||
current := pendingPackets[i]
|
||||
j := i - 1
|
||||
for ; j >= 0 && inboundPacketSortOrder(current.Packet.Type) < inboundPacketSortOrder(pendingPackets[j].Packet.Type); j-- {
|
||||
pendingPackets[j+1] = pendingPackets[j]
|
||||
}
|
||||
pendingPackets[j+1] = current
|
||||
}
|
||||
}
|
||||
|
||||
func inboundPacketSortOrder(packetType protocol.PacketType) int {
|
||||
switch packetType {
|
||||
case protocol.PacketTypeSOCKSData:
|
||||
return 0
|
||||
case protocol.PacketTypeSOCKSCloseRead:
|
||||
return 1
|
||||
case protocol.PacketTypeSOCKSCloseWrite:
|
||||
return 2
|
||||
case protocol.PacketTypeSOCKSRST:
|
||||
return 3
|
||||
default:
|
||||
return 4
|
||||
}
|
||||
}
|
||||
|
||||
func isReorderSequencedPacket(packetType protocol.PacketType) bool {
|
||||
switch packetType {
|
||||
case protocol.PacketTypeSOCKSData,
|
||||
|
||||
@@ -9,6 +9,8 @@ package client
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -74,6 +76,9 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
||||
|
||||
if err := batch.Validate(); err != nil {
|
||||
c.log.Errorf("<red>worker=<cyan>%d</cyan> invalid batch: <cyan>%v</cyan></red>", w.id, err)
|
||||
if isPingOnlyBatch(batch) {
|
||||
c.failPing()
|
||||
}
|
||||
c.requeueSelected(selected)
|
||||
c.releaseBatchSlot()
|
||||
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
|
||||
@@ -85,6 +90,9 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
||||
body, err := protocol.EncryptBatch(batch, c.cfg.AESEncryptionKey)
|
||||
if err != nil {
|
||||
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
|
||||
if isPingOnlyBatch(batch) {
|
||||
c.failPing()
|
||||
}
|
||||
c.requeueSelected(selected)
|
||||
c.releaseBatchSlot()
|
||||
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
|
||||
@@ -104,6 +112,9 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
||||
"<green>worker=<cyan>%d</cyan> sent batch=<cyan>%s</cyan> packets=<cyan>%d</cyan> bytes=<cyan>%d</cyan></green>",
|
||||
w.id, batch.BatchID, len(batch.Packets), len(body),
|
||||
)
|
||||
if !isPingOnlyBatch(batch) {
|
||||
c.noteMeaningfulActivity(time.Now())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -119,15 +130,12 @@ func (c *Client) waitForSendWork(ctx context.Context, interval time.Duration) {
|
||||
}
|
||||
|
||||
func (c *Client) buildNextBatch(connections []*SOCKSConnection, totalQueuedBytes int) (protocol.Batch, []dequeuedPacket) {
|
||||
if len(connections) == 0 {
|
||||
return protocol.Batch{}, nil
|
||||
}
|
||||
|
||||
sort.Slice(connections, func(i, j int) bool {
|
||||
return connections[i].ID < connections[j].ID
|
||||
})
|
||||
|
||||
start := 0
|
||||
if len(connections) > 0 {
|
||||
sort.Slice(connections, func(i, j int) bool {
|
||||
return connections[i].ID < connections[j].ID
|
||||
})
|
||||
}
|
||||
if len(connections) > 1 {
|
||||
rotationEvery := c.cfg.MuxRotateEveryBatches
|
||||
if rotationEvery < 1 {
|
||||
@@ -195,28 +203,45 @@ func (c *Client) buildNextBatch(connections []*SOCKSConnection, totalQueuedBytes
|
||||
}
|
||||
|
||||
func (c *Client) buildPollBatch(connections []*SOCKSConnection, totalQueuedBytes int) (protocol.Batch, bool) {
|
||||
if len(connections) == 0 {
|
||||
return protocol.Batch{}, false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
nowUnixMS := now.UnixMilli()
|
||||
lastUnixMS := c.lastPollUnixMS.Load()
|
||||
minInterval := c.jitterDuration(c.effectiveIdlePollInterval(totalQueuedBytes))
|
||||
if lastUnixMS > 0 && nowUnixMS-lastUnixMS < minInterval.Milliseconds() {
|
||||
if !c.shouldSendPing(connections, totalQueuedBytes, now) {
|
||||
return protocol.Batch{}, false
|
||||
}
|
||||
|
||||
if !c.lastPollUnixMS.CompareAndSwap(lastUnixMS, nowUnixMS) {
|
||||
if !c.tryBeginPing(now) {
|
||||
return protocol.Batch{}, false
|
||||
}
|
||||
|
||||
packet := protocol.NewPacket(c.clientSessionKey, protocol.PacketTypePing)
|
||||
packet.Payload = []byte("poll")
|
||||
packet.Payload = buildPingPayload(now)
|
||||
batch := protocol.NewBatch(c.clientSessionKey, protocol.NewBatchID(), []protocol.Packet{packet})
|
||||
return batch, true
|
||||
}
|
||||
|
||||
func (c *Client) shouldSendPing(connections []*SOCKSConnection, totalQueuedBytes int, now time.Time) bool {
|
||||
if totalQueuedBytes > 0 {
|
||||
return false
|
||||
}
|
||||
if hasQueuedPackets(connections) {
|
||||
return false
|
||||
}
|
||||
if c.pingInFlight.Load() > 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
lastMeaningfulUnixMS := c.lastMeaningfulActivityUnixMS.Load()
|
||||
sessionActive := len(connections) > 0 || lastMeaningfulUnixMS > 0
|
||||
if !sessionActive {
|
||||
return false
|
||||
}
|
||||
|
||||
nextDueUnixMS := c.nextPingDueUnixMS.Load()
|
||||
if nextDueUnixMS <= 0 {
|
||||
return lastMeaningfulUnixMS > 0
|
||||
}
|
||||
|
||||
return now.UnixMilli() >= nextDueUnixMS
|
||||
}
|
||||
|
||||
func (c *Client) effectiveBatchLimits(totalQueuedBytes int) (int, int) {
|
||||
maxPackets := c.cfg.MaxPacketsPerBatch
|
||||
maxBatchBytes := c.cfg.MaxBatchBytes
|
||||
@@ -260,16 +285,6 @@ func (c *Client) effectiveWaitInterval(totalQueuedBytes int) time.Duration {
|
||||
return interval
|
||||
}
|
||||
|
||||
func (c *Client) effectiveIdlePollInterval(totalQueuedBytes int) time.Duration {
|
||||
interval := time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond
|
||||
if totalQueuedBytes >= c.cfg.MuxBurstThresholdBytes {
|
||||
if burst := interval / 2; burst >= time.Duration(c.cfg.WorkerPollIntervalMS)*time.Millisecond {
|
||||
return burst
|
||||
}
|
||||
}
|
||||
return interval
|
||||
}
|
||||
|
||||
func (c *Client) effectiveConcurrentBatches(totalQueuedBytes int) int {
|
||||
if totalQueuedBytes >= c.cfg.MuxBurstThresholdBytes {
|
||||
return c.cfg.MaxConcurrentBatches
|
||||
@@ -315,6 +330,29 @@ func queuedBytesAcross(connections []*SOCKSConnection) int {
|
||||
return total
|
||||
}
|
||||
|
||||
func hasQueuedPackets(connections []*SOCKSConnection) bool {
|
||||
for _, socksConn := range connections {
|
||||
queueItems, _ := socksConn.QueueSnapshot()
|
||||
if queueItems > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func buildPingPayload(now time.Time) []byte {
|
||||
random := make([]byte, 8)
|
||||
if _, err := rand.Read(random); err != nil {
|
||||
return []byte(fmt.Sprintf("ping:%d:fallback", now.UnixMilli()))
|
||||
}
|
||||
|
||||
return []byte(fmt.Sprintf("ping:%d:%s", now.UnixMilli(), hex.EncodeToString(random)))
|
||||
}
|
||||
|
||||
func isPingOnlyBatch(batch protocol.Batch) bool {
|
||||
return len(batch.Packets) == 1 && batch.Packets[0].Type == protocol.PacketTypePing
|
||||
}
|
||||
|
||||
func (c *Client) jitterDuration(base time.Duration) time.Duration {
|
||||
if base <= 0 || c.cfg.HTTPTimingJitterMS <= 0 {
|
||||
return base
|
||||
@@ -391,8 +429,12 @@ func (c *Client) reclaimExpiredReorder() {
|
||||
}
|
||||
|
||||
func (w *sendWorker) postBatch(ctx context.Context, c *Client, batch protocol.Batch, body []byte) error {
|
||||
pingOnly := isPingOnlyBatch(batch)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.cfg.RelayURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -404,16 +446,25 @@ func (w *sendWorker) postBatch(ctx context.Context, c *Client, batch protocol.Ba
|
||||
|
||||
resp, err := w.httpClient.Do(req)
|
||||
if err != nil {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
|
||||
return fmt.Errorf("unexpected status %d: %s", resp.StatusCode, string(respBody))
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
c.log.Debugf(
|
||||
"<gray>worker=<cyan>%d</cyan> batch=<cyan>%s</cyan> got no-content response</gray>",
|
||||
w.id, batch.BatchID,
|
||||
@@ -423,15 +474,24 @@ func (w *sendWorker) postBatch(ctx context.Context, c *Client, batch protocol.Ba
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
if len(respBody) == 0 {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
responseBatch, err := protocol.DecryptBatch(respBody, c.cfg.AESEncryptionKey)
|
||||
if err != nil {
|
||||
if pingOnly {
|
||||
c.failPing()
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -449,6 +509,13 @@ func (w *sendWorker) postBatch(ctx context.Context, c *Client, batch protocol.Ba
|
||||
|
||||
func (c *Client) applyResponseBatch(batch protocol.Batch) error {
|
||||
for _, packet := range batch.Packets {
|
||||
if packet.Type == protocol.PacketTypePong {
|
||||
c.completePingWithPong()
|
||||
}
|
||||
if packet.Type != protocol.PacketTypePing && packet.Type != protocol.PacketTypePong {
|
||||
c.noteMeaningfulActivity(time.Now())
|
||||
}
|
||||
|
||||
c.log.Debugf(
|
||||
"<gray>apply response packet=<cyan>%s</cyan> socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> payload_bytes=<cyan>%d</cyan> final=<cyan>%t</cyan></gray>",
|
||||
packet.Type, packet.SOCKSID, packet.Sequence, len(packet.Payload), packet.Final,
|
||||
|
||||
@@ -9,6 +9,27 @@ import (
|
||||
"masterhttprelayvpn/internal/protocol"
|
||||
)
|
||||
|
||||
func testClientConfig() config.Config {
|
||||
return config.Config{
|
||||
MaxChunkSize: 1024,
|
||||
MaxPacketsPerBatch: 4,
|
||||
MaxBatchBytes: 4096,
|
||||
WorkerCount: 2,
|
||||
MaxConcurrentBatches: 2,
|
||||
MaxPacketsPerSOCKSPerBatch: 1,
|
||||
MuxRotateEveryBatches: 1,
|
||||
MuxBurstThresholdBytes: 1024,
|
||||
WorkerPollIntervalMS: 200,
|
||||
IdlePollIntervalMS: 1000,
|
||||
PingWarmThresholdMS: 5000,
|
||||
PingBackoffBaseMS: 5000,
|
||||
PingBackoffStepMS: 5000,
|
||||
PingMaxIntervalMS: 60000,
|
||||
MaxQueueBytesPerSOCKS: 4096,
|
||||
HTTPBatchRandomize: false,
|
||||
}
|
||||
}
|
||||
|
||||
func TestSOCKSConnectionStoreDeleteClearsTransportState(t *testing.T) {
|
||||
store := NewSOCKSConnectionStore()
|
||||
chunkPolicy := ChunkPolicy{
|
||||
@@ -71,7 +92,7 @@ func TestSOCKSConnectionStoreDeleteClearsTransportState(t *testing.T) {
|
||||
func TestSOCKSConnectionInboundReorderQueuesAndDrainsInOrder(t *testing.T) {
|
||||
socksConn := &SOCKSConnection{
|
||||
ConnectAccepted: true,
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
|
||||
packet2 := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSData)
|
||||
@@ -106,12 +127,12 @@ func TestSOCKSConnectionInboundReorderQueuesAndDrainsInOrder(t *testing.T) {
|
||||
|
||||
func TestSOCKSConnectionInboundGapTimeout(t *testing.T) {
|
||||
socksConn := &SOCKSConnection{
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
socksConn.PendingInbound[5] = PendingInboundPacket{
|
||||
socksConn.PendingInbound[5] = []PendingInboundPacket{{
|
||||
Packet: protocol.Packet{Sequence: 5},
|
||||
QueuedAt: time.Now().Add(-2 * time.Second),
|
||||
}
|
||||
}}
|
||||
|
||||
if !socksConn.hasExpiredInboundGap(500 * time.Millisecond) {
|
||||
t.Fatal("expected inbound gap timeout to trigger")
|
||||
@@ -123,7 +144,7 @@ func TestSOCKSConnectionInboundGapTimeout(t *testing.T) {
|
||||
|
||||
func TestSOCKSConnectionInboundDataWaitsForConnectAck(t *testing.T) {
|
||||
socksConn := &SOCKSConnection{
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
|
||||
packet1 := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSData)
|
||||
@@ -150,18 +171,10 @@ func TestSOCKSConnectionInboundDataWaitsForConnectAck(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
MaxChunkSize: 1024,
|
||||
MaxPacketsPerBatch: 1,
|
||||
MaxBatchBytes: 4096,
|
||||
WorkerCount: 1,
|
||||
MaxConcurrentBatches: 1,
|
||||
MaxPacketsPerSOCKSPerBatch: 1,
|
||||
MuxRotateEveryBatches: 1,
|
||||
MuxBurstThresholdBytes: 1024,
|
||||
MaxQueueBytesPerSOCKS: 4096,
|
||||
HTTPBatchRandomize: false,
|
||||
}
|
||||
cfg := testClientConfig()
|
||||
cfg.MaxPacketsPerBatch = 1
|
||||
cfg.WorkerCount = 1
|
||||
cfg.MaxConcurrentBatches = 1
|
||||
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
@@ -195,18 +208,7 @@ func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestBuildNextBatchHonorsPerSOCKSPacketLimit(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
MaxChunkSize: 1024,
|
||||
MaxPacketsPerBatch: 4,
|
||||
MaxBatchBytes: 4096,
|
||||
WorkerCount: 2,
|
||||
MaxConcurrentBatches: 2,
|
||||
MaxPacketsPerSOCKSPerBatch: 1,
|
||||
MuxRotateEveryBatches: 1,
|
||||
MuxBurstThresholdBytes: 1024,
|
||||
MaxQueueBytesPerSOCKS: 4096,
|
||||
HTTPBatchRandomize: false,
|
||||
}
|
||||
cfg := testClientConfig()
|
||||
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
@@ -242,13 +244,10 @@ func TestBuildNextBatchHonorsPerSOCKSPacketLimit(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestEffectiveConcurrentBatchesUsesBurstThreshold(t *testing.T) {
|
||||
cfg := config.Config{
|
||||
WorkerCount: 4,
|
||||
MaxConcurrentBatches: 3,
|
||||
MaxPacketsPerSOCKSPerBatch: 2,
|
||||
MuxRotateEveryBatches: 1,
|
||||
MuxBurstThresholdBytes: 4096,
|
||||
}
|
||||
cfg := testClientConfig()
|
||||
cfg.WorkerCount = 4
|
||||
cfg.MaxConcurrentBatches = 3
|
||||
cfg.MuxBurstThresholdBytes = 4096
|
||||
|
||||
client := New(cfg, nil)
|
||||
if got := client.effectiveConcurrentBatches(1024); got != 1 {
|
||||
@@ -258,3 +257,232 @@ func TestEffectiveConcurrentBatchesUsesBurstThreshold(t *testing.T) {
|
||||
t.Fatalf("expected burst concurrency of 3, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPollBatchSkipsWhenTransportBusy(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
|
||||
socksConn := client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
|
||||
if err := socksConn.EnqueuePacket(socksConn.BuildSOCKSDataPacket([]byte("busy"), false)); err != nil {
|
||||
t.Fatalf("enqueue packet: %v", err)
|
||||
}
|
||||
|
||||
batch, ok := client.buildPollBatch(client.socksConnections.Snapshot(), queuedBytesAcross(client.socksConnections.Snapshot()))
|
||||
if ok || len(batch.Packets) != 0 {
|
||||
t.Fatal("expected poll batch to be suppressed while queued payload exists")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPollBatchAllowsOnlySinglePingInFlight(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
|
||||
client.noteMeaningfulActivity(time.Now().Add(-10 * time.Second))
|
||||
|
||||
batch, ok := client.buildPollBatch(client.socksConnections.Snapshot(), 0)
|
||||
if !ok || len(batch.Packets) != 1 || batch.Packets[0].Type != protocol.PacketTypePing {
|
||||
t.Fatal("expected first idle batch to be a ping")
|
||||
}
|
||||
|
||||
batch, ok = client.buildPollBatch(client.socksConnections.Snapshot(), 0)
|
||||
if ok || len(batch.Packets) != 0 {
|
||||
t.Fatal("expected second ping to be suppressed while first ping is still in flight")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPollBatchAllowsSessionPingWithoutActiveConnections(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
|
||||
now := time.Now()
|
||||
client.noteMeaningfulActivity(now.Add(-10 * time.Second))
|
||||
client.nextPingDueUnixMS.Store(now.Add(-1 * time.Second).UnixMilli())
|
||||
|
||||
batch, ok := client.buildPollBatch(nil, 0)
|
||||
if !ok || len(batch.Packets) != 1 || batch.Packets[0].Type != protocol.PacketTypePing {
|
||||
t.Fatal("expected session-level ping even without active socks connections")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildPollBatchSkipsWithoutSessionActivity(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.nextPingDueUnixMS.Store(time.Now().Add(-1 * time.Second).UnixMilli())
|
||||
|
||||
batch, ok := client.buildPollBatch(nil, 0)
|
||||
if ok || len(batch.Packets) != 0 {
|
||||
t.Fatal("expected ping to stay suppressed before the session has any real activity")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSendPingWhenIdleIntervalHasElapsed(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
|
||||
|
||||
now := time.Now()
|
||||
client.nextPingDueUnixMS.Store(now.Add(-2 * time.Second).UnixMilli())
|
||||
if !client.shouldSendPing(client.socksConnections.Snapshot(), 0, now) {
|
||||
t.Fatal("expected ping to be due after idle interval elapsed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldNotSendPingBeforeIdleInterval(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
|
||||
|
||||
now := time.Now()
|
||||
client.nextPingDueUnixMS.Store(now.Add(500 * time.Millisecond).UnixMilli())
|
||||
if client.shouldSendPing(client.socksConnections.Snapshot(), 0, now) {
|
||||
t.Fatal("expected ping to stay suppressed until idle interval elapses")
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldSendPingWithOnlyInFlightPackets(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
socksConn := client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
|
||||
|
||||
packet := socksConn.BuildSOCKSDataPacket([]byte("hello"), false)
|
||||
item := &SOCKSOutboundQueueItem{
|
||||
IdentityKey: protocol.PacketIdentityKey(
|
||||
packet.ClientSessionKey,
|
||||
packet.SOCKSID,
|
||||
packet.Type,
|
||||
packet.Sequence,
|
||||
packet.FragmentID,
|
||||
),
|
||||
Packet: packet,
|
||||
QueuedAt: time.Now(),
|
||||
SentAt: time.Now(),
|
||||
PayloadSize: len(packet.Payload),
|
||||
}
|
||||
socksConn.MarkInFlight([]*SOCKSOutboundQueueItem{item})
|
||||
|
||||
now := time.Now()
|
||||
client.noteMeaningfulActivity(now.Add(-10 * time.Second))
|
||||
client.nextPingDueUnixMS.Store(now.Add(-1 * time.Second).UnixMilli())
|
||||
|
||||
if !client.shouldSendPing(client.socksConnections.Snapshot(), 0, now) {
|
||||
t.Fatal("expected ping to be allowed while only in-flight packets remain")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIdleIntervalForStreakBacksOffWithIdlePongs(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
|
||||
if got := client.idleIntervalForStreak(0); got != 5*time.Second {
|
||||
t.Fatalf("expected base backoff interval, got %v", got)
|
||||
}
|
||||
|
||||
if got := client.idleIntervalForStreak(1); got != 10*time.Second {
|
||||
t.Fatalf("expected first stepped backoff interval, got %v", got)
|
||||
}
|
||||
|
||||
if got := client.idleIntervalForStreak(20); got != 60*time.Second {
|
||||
t.Fatalf("expected capped backoff interval, got %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePingWithPongIncrementsStreakOnlyWithoutRealTraffic(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
now := time.Now()
|
||||
|
||||
client.noteMeaningfulActivity(now.Add(-10 * time.Second))
|
||||
if !client.tryBeginPing(now) {
|
||||
t.Fatal("expected ping to start")
|
||||
}
|
||||
client.completePingWithPong()
|
||||
if got := client.idlePongStreak.Load(); got != 1 {
|
||||
t.Fatalf("expected pong streak to increment to 1, got %d", got)
|
||||
}
|
||||
nextDue := client.nextPingDueUnixMS.Load()
|
||||
if nextDue <= now.UnixMilli() {
|
||||
t.Fatal("expected next ping due to be scheduled in the future after idle pong")
|
||||
}
|
||||
|
||||
client.noteMeaningfulActivity(now.Add(1 * time.Second))
|
||||
if !client.tryBeginPing(now.Add(2 * time.Second)) {
|
||||
t.Fatal("expected second ping to start")
|
||||
}
|
||||
client.noteMeaningfulActivity(now.Add(3 * time.Second))
|
||||
client.completePingWithPong()
|
||||
if got := client.idlePongStreak.Load(); got != 0 {
|
||||
t.Fatalf("expected pong streak reset after real traffic, got %d", got)
|
||||
}
|
||||
if client.nextPingDueUnixMS.Load() <= now.UnixMilli() {
|
||||
t.Fatal("expected next ping due to be rescheduled after meaningful traffic")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletePingWithPongStaysAggressiveBeforeWarmThreshold(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
now := time.Now()
|
||||
|
||||
client.noteMeaningfulActivity(now.Add(-3 * time.Second))
|
||||
if !client.tryBeginPing(now) {
|
||||
t.Fatal("expected ping to start")
|
||||
}
|
||||
|
||||
client.completePingWithPong()
|
||||
|
||||
if got := client.idlePongStreak.Load(); got != 0 {
|
||||
t.Fatalf("expected pong streak to stay at 0 before warm threshold, got %d", got)
|
||||
}
|
||||
|
||||
nextDue := client.nextPingDueUnixMS.Load()
|
||||
expectedMin := now.Add(900 * time.Millisecond).UnixMilli()
|
||||
expectedMax := now.Add(1100 * time.Millisecond).UnixMilli()
|
||||
if nextDue < expectedMin || nextDue > expectedMax {
|
||||
t.Fatalf("expected aggressive next ping around idle interval, got %d", nextDue)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInboundReorderAllowsCloseReadAndCloseWriteOnSameSequence(t *testing.T) {
|
||||
cfg := testClientConfig()
|
||||
client := New(cfg, nil)
|
||||
client.chunkPolicy = newChunkPolicy(cfg)
|
||||
|
||||
socksConn := client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
|
||||
socksConn.ConnectAccepted = true
|
||||
|
||||
closeWrite := protocol.NewPacket(client.clientSessionKey, protocol.PacketTypeSOCKSCloseWrite)
|
||||
closeWrite.SOCKSID = socksConn.ID
|
||||
closeWrite.Sequence = 2
|
||||
|
||||
closeRead := protocol.NewPacket(client.clientSessionKey, protocol.PacketTypeSOCKSCloseRead)
|
||||
closeRead.SOCKSID = socksConn.ID
|
||||
closeRead.Sequence = 2
|
||||
|
||||
ready, duplicate, overflow := socksConn.queueInboundPacket(closeWrite, 8)
|
||||
if duplicate || overflow || len(ready) != 0 {
|
||||
t.Fatalf("expected first close packet to buffer, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
||||
}
|
||||
|
||||
ready, duplicate, overflow = socksConn.queueInboundPacket(closeRead, 8)
|
||||
if duplicate || overflow || len(ready) != 0 {
|
||||
t.Fatalf("expected second close packet on same sequence to buffer, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
||||
}
|
||||
|
||||
data := protocol.NewPacket(client.clientSessionKey, protocol.PacketTypeSOCKSData)
|
||||
data.SOCKSID = socksConn.ID
|
||||
data.Sequence = 1
|
||||
data.Payload = []byte("ok")
|
||||
|
||||
ready, duplicate, overflow = socksConn.queueInboundPacket(data, 8)
|
||||
if duplicate || overflow || len(ready) != 3 {
|
||||
t.Fatalf("expected data and both close packets to drain, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
||||
}
|
||||
if ready[0].Type != protocol.PacketTypeSOCKSData || ready[1].Type != protocol.PacketTypeSOCKSCloseRead || ready[2].Type != protocol.PacketTypeSOCKSCloseWrite {
|
||||
t.Fatalf("unexpected drain order: %s, %s, %s", ready[0].Type, ready[1].Type, ready[2].Type)
|
||||
}
|
||||
}
|
||||
|
||||
+15
-15
@@ -39,21 +39,21 @@ type SOCKSConnection struct {
|
||||
CloseWriteSent bool
|
||||
ResetSent bool
|
||||
|
||||
LocalConn net.Conn
|
||||
localWriteMu sync.Mutex
|
||||
localCloseMu sync.Mutex
|
||||
reorderMu sync.Mutex
|
||||
localReadEOF bool
|
||||
localWriteEOF bool
|
||||
closedC chan struct{}
|
||||
closeOnce sync.Once
|
||||
connectResultC chan error
|
||||
queueMu sync.Mutex
|
||||
OutboundQueue []*SOCKSOutboundQueueItem
|
||||
QueuedBytes int
|
||||
InFlight map[string]*SOCKSOutboundQueueItem
|
||||
LocalConn net.Conn
|
||||
localWriteMu sync.Mutex
|
||||
localCloseMu sync.Mutex
|
||||
reorderMu sync.Mutex
|
||||
localReadEOF bool
|
||||
localWriteEOF bool
|
||||
closedC chan struct{}
|
||||
closeOnce sync.Once
|
||||
connectResultC chan error
|
||||
queueMu sync.Mutex
|
||||
OutboundQueue []*SOCKSOutboundQueueItem
|
||||
QueuedBytes int
|
||||
InFlight map[string]*SOCKSOutboundQueueItem
|
||||
NextInboundSequence uint64
|
||||
PendingInbound map[uint64]PendingInboundPacket
|
||||
PendingInbound map[uint64][]PendingInboundPacket
|
||||
}
|
||||
|
||||
type PendingInboundPacket struct {
|
||||
@@ -93,7 +93,7 @@ func (s *SOCKSConnectionStore) New(clientSessionKey string, clientAddress string
|
||||
closedC: make(chan struct{}),
|
||||
connectResultC: make(chan error, 1),
|
||||
InFlight: make(map[string]*SOCKSOutboundQueueItem),
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
|
||||
@@ -100,6 +100,12 @@ func (s *SOCKSConnection) QueueSnapshot() (items int, bytes int) {
|
||||
return len(s.OutboundQueue), s.QueuedBytes
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) InFlightCount() int {
|
||||
s.queueMu.Lock()
|
||||
defer s.queueMu.Unlock()
|
||||
return len(s.InFlight)
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) DequeuePacket() *SOCKSOutboundQueueItem {
|
||||
s.queueMu.Lock()
|
||||
defer s.queueMu.Unlock()
|
||||
|
||||
@@ -49,6 +49,10 @@ type Config struct {
|
||||
HTTPRequestTimeoutMS int
|
||||
WorkerPollIntervalMS int
|
||||
IdlePollIntervalMS int
|
||||
PingWarmThresholdMS int
|
||||
PingBackoffBaseMS int
|
||||
PingBackoffStepMS int
|
||||
PingMaxIntervalMS int
|
||||
MaxQueueBytesPerSOCKS int
|
||||
AckTimeoutMS int
|
||||
MaxRetryCount int
|
||||
@@ -88,6 +92,10 @@ func Load(path string) (Config, error) {
|
||||
HTTPRequestTimeoutMS: 15000,
|
||||
WorkerPollIntervalMS: 200,
|
||||
IdlePollIntervalMS: 1000,
|
||||
PingWarmThresholdMS: 5000,
|
||||
PingBackoffBaseMS: 5000,
|
||||
PingBackoffStepMS: 5000,
|
||||
PingMaxIntervalMS: 60000,
|
||||
MaxQueueBytesPerSOCKS: 1024 * 1024,
|
||||
AckTimeoutMS: 5000,
|
||||
MaxRetryCount: 5,
|
||||
@@ -292,6 +300,34 @@ func Load(path string) (Config, error) {
|
||||
}
|
||||
|
||||
cfg.IdlePollIntervalMS = interval
|
||||
case "PING_WARM_THRESHOLD_MS":
|
||||
threshold, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("parse PING_WARM_THRESHOLD_MS: %w", err)
|
||||
}
|
||||
|
||||
cfg.PingWarmThresholdMS = threshold
|
||||
case "PING_BACKOFF_BASE_MS":
|
||||
interval, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("parse PING_BACKOFF_BASE_MS: %w", err)
|
||||
}
|
||||
|
||||
cfg.PingBackoffBaseMS = interval
|
||||
case "PING_BACKOFF_STEP_MS":
|
||||
interval, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("parse PING_BACKOFF_STEP_MS: %w", err)
|
||||
}
|
||||
|
||||
cfg.PingBackoffStepMS = interval
|
||||
case "PING_MAX_INTERVAL_MS":
|
||||
interval, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("parse PING_MAX_INTERVAL_MS: %w", err)
|
||||
}
|
||||
|
||||
cfg.PingMaxIntervalMS = interval
|
||||
case "MAX_QUEUE_BYTES_PER_SOCKS":
|
||||
size, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
@@ -407,6 +443,18 @@ func (c Config) ValidateClient() error {
|
||||
if c.IdlePollIntervalMS < c.WorkerPollIntervalMS {
|
||||
return fmt.Errorf("IDLE_POLL_INTERVAL_MS must be >= WORKER_POLL_INTERVAL_MS")
|
||||
}
|
||||
if c.PingWarmThresholdMS < 1 {
|
||||
return fmt.Errorf("invalid PING_WARM_THRESHOLD_MS: %d", c.PingWarmThresholdMS)
|
||||
}
|
||||
if c.PingBackoffBaseMS < c.IdlePollIntervalMS {
|
||||
return fmt.Errorf("PING_BACKOFF_BASE_MS must be >= IDLE_POLL_INTERVAL_MS")
|
||||
}
|
||||
if c.PingBackoffStepMS < 1 {
|
||||
return fmt.Errorf("invalid PING_BACKOFF_STEP_MS: %d", c.PingBackoffStepMS)
|
||||
}
|
||||
if c.PingMaxIntervalMS < c.PingBackoffBaseMS {
|
||||
return fmt.Errorf("PING_MAX_INTERVAL_MS must be >= PING_BACKOFF_BASE_MS")
|
||||
}
|
||||
|
||||
if c.AckTimeoutMS < 1 {
|
||||
return fmt.Errorf("invalid ACK_TIMEOUT_MS: %d", c.AckTimeoutMS)
|
||||
|
||||
+87
-36
@@ -39,29 +39,29 @@ type ClientSession struct {
|
||||
}
|
||||
|
||||
type SOCKSState struct {
|
||||
ID uint64
|
||||
CreatedAt time.Time
|
||||
LastActivityAt time.Time
|
||||
Target *protocol.Target
|
||||
ConnectSeen bool
|
||||
ConnectAcked bool
|
||||
CloseReadSeen bool
|
||||
CloseWriteSeen bool
|
||||
ResetSeen bool
|
||||
ReceivedBytes uint64
|
||||
LastSequenceSeen uint64
|
||||
ID uint64
|
||||
CreatedAt time.Time
|
||||
LastActivityAt time.Time
|
||||
Target *protocol.Target
|
||||
ConnectSeen bool
|
||||
ConnectAcked bool
|
||||
CloseReadSeen bool
|
||||
CloseWriteSeen bool
|
||||
ResetSeen bool
|
||||
ReceivedBytes uint64
|
||||
LastSequenceSeen uint64
|
||||
NextInboundSequence uint64
|
||||
OutboundSequence uint64
|
||||
PendingInbound map[uint64]PendingInboundPacket
|
||||
UpstreamConn net.Conn
|
||||
upstreamWriteMu sync.Mutex
|
||||
upstreamCloseMu sync.Mutex
|
||||
upstreamReadEOF bool
|
||||
upstreamWriteEOF bool
|
||||
queueMu sync.Mutex
|
||||
OutboundQueue []protocol.Packet
|
||||
QueuedBytes int
|
||||
MaxQueueBytes int
|
||||
OutboundSequence uint64
|
||||
PendingInbound map[uint64][]PendingInboundPacket
|
||||
UpstreamConn net.Conn
|
||||
upstreamWriteMu sync.Mutex
|
||||
upstreamCloseMu sync.Mutex
|
||||
upstreamReadEOF bool
|
||||
upstreamWriteEOF bool
|
||||
queueMu sync.Mutex
|
||||
OutboundQueue []protocol.Packet
|
||||
QueuedBytes int
|
||||
MaxQueueBytes int
|
||||
}
|
||||
|
||||
type PendingInboundPacket struct {
|
||||
@@ -230,7 +230,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
ConnectSeen: true,
|
||||
LastSequenceSeen: packet.Sequence,
|
||||
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
session.SOCKSConnections[packet.SOCKSID] = socksState
|
||||
} else {
|
||||
@@ -238,7 +238,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
socksState.Target = packet.Target
|
||||
socksState.ConnectSeen = true
|
||||
if socksState.PendingInbound == nil {
|
||||
socksState.PendingInbound = make(map[uint64]PendingInboundPacket)
|
||||
socksState.PendingInbound = make(map[uint64][]PendingInboundPacket)
|
||||
}
|
||||
if packet.Sequence > socksState.LastSequenceSeen {
|
||||
socksState.LastSequenceSeen = packet.Sequence
|
||||
@@ -380,7 +380,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
|
||||
socksState := session.SOCKSConnections[packet.SOCKSID]
|
||||
if socksState != nil {
|
||||
if socksState.PendingInbound == nil {
|
||||
socksState.PendingInbound = make(map[uint64]PendingInboundPacket)
|
||||
socksState.PendingInbound = make(map[uint64][]PendingInboundPacket)
|
||||
}
|
||||
return socksState
|
||||
}
|
||||
@@ -392,7 +392,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
|
||||
Target: packet.Target,
|
||||
LastSequenceSeen: packet.Sequence,
|
||||
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
session.SOCKSConnections[packet.SOCKSID] = socksState
|
||||
s.log.Debugf(
|
||||
@@ -625,16 +625,17 @@ func (s *SOCKSState) queueInboundPacketLocked(packet protocol.Packet, now time.T
|
||||
if packet.Sequence < expected {
|
||||
return nil, true, false
|
||||
}
|
||||
if _, exists := s.PendingInbound[packet.Sequence]; exists {
|
||||
pendingForSequence := s.PendingInbound[packet.Sequence]
|
||||
if containsPendingInboundPacketLocked(pendingForSequence, packet) {
|
||||
return nil, true, false
|
||||
}
|
||||
if len(s.PendingInbound) >= maxBuffered {
|
||||
if bufferedInboundPacketCountLocked(s.PendingInbound) >= maxBuffered {
|
||||
return nil, false, true
|
||||
}
|
||||
s.PendingInbound[packet.Sequence] = PendingInboundPacket{
|
||||
s.PendingInbound[packet.Sequence] = append(s.PendingInbound[packet.Sequence], PendingInboundPacket{
|
||||
Packet: packet,
|
||||
QueuedAt: now,
|
||||
}
|
||||
})
|
||||
if !s.ConnectAcked {
|
||||
return nil, false, false
|
||||
}
|
||||
@@ -645,11 +646,14 @@ func (s *SOCKSState) drainReadyInboundLocked() []protocol.Packet {
|
||||
expected := s.expectedInboundSequenceLocked()
|
||||
ready := make([]protocol.Packet, 0)
|
||||
for {
|
||||
pending, ok := s.PendingInbound[expected]
|
||||
if !ok {
|
||||
pendingPackets, ok := s.PendingInbound[expected]
|
||||
if !ok || len(pendingPackets) == 0 {
|
||||
break
|
||||
}
|
||||
ready = append(ready, pending.Packet)
|
||||
sortPendingInboundPacketsLocked(pendingPackets)
|
||||
for _, pending := range pendingPackets {
|
||||
ready = append(ready, pending.Packet)
|
||||
}
|
||||
delete(s.PendingInbound, expected)
|
||||
expected++
|
||||
}
|
||||
@@ -658,15 +662,62 @@ func (s *SOCKSState) drainReadyInboundLocked() []protocol.Packet {
|
||||
}
|
||||
|
||||
func (s *SOCKSState) hasExpiredInboundGapLocked(now time.Time, timeout time.Duration) bool {
|
||||
for _, pending := range s.PendingInbound {
|
||||
if now.Sub(pending.QueuedAt) >= timeout {
|
||||
clear(s.PendingInbound)
|
||||
for _, pendingPackets := range s.PendingInbound {
|
||||
for _, pending := range pendingPackets {
|
||||
if now.Sub(pending.QueuedAt) >= timeout {
|
||||
clear(s.PendingInbound)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsPendingInboundPacketLocked(pendingPackets []PendingInboundPacket, packet protocol.Packet) bool {
|
||||
for _, pending := range pendingPackets {
|
||||
if pending.Packet.Type == packet.Type &&
|
||||
pending.Packet.FragmentID == packet.FragmentID &&
|
||||
pending.Packet.TotalFragments == packet.TotalFragments {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func bufferedInboundPacketCountLocked(pending map[uint64][]PendingInboundPacket) int {
|
||||
total := 0
|
||||
for _, pendingPackets := range pending {
|
||||
total += len(pendingPackets)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func sortPendingInboundPacketsLocked(pendingPackets []PendingInboundPacket) {
|
||||
for i := 1; i < len(pendingPackets); i++ {
|
||||
current := pendingPackets[i]
|
||||
j := i - 1
|
||||
for ; j >= 0 && inboundPacketSortOrderLocked(current.Packet.Type) < inboundPacketSortOrderLocked(pendingPackets[j].Packet.Type); j-- {
|
||||
pendingPackets[j+1] = pendingPackets[j]
|
||||
}
|
||||
pendingPackets[j+1] = current
|
||||
}
|
||||
}
|
||||
|
||||
func inboundPacketSortOrderLocked(packetType protocol.PacketType) int {
|
||||
switch packetType {
|
||||
case protocol.PacketTypeSOCKSData:
|
||||
return 0
|
||||
case protocol.PacketTypeSOCKSCloseRead:
|
||||
return 1
|
||||
case protocol.PacketTypeSOCKSCloseWrite:
|
||||
return 2
|
||||
case protocol.PacketTypeSOCKSRST:
|
||||
return 3
|
||||
default:
|
||||
return 4
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool {
|
||||
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
|
||||
packet.SOCKSID = s.ID
|
||||
|
||||
@@ -64,9 +64,9 @@ func TestDrainSessionOutboundLockedRespectsGlobalLimits(t *testing.T) {
|
||||
|
||||
func TestSOCKSStateInboundReorderQueuesUntilGapFilled(t *testing.T) {
|
||||
socksState := &SOCKSState{
|
||||
ConnectAcked: true,
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
MaxQueueBytes: 1024,
|
||||
ConnectAcked: true,
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
MaxQueueBytes: 1024,
|
||||
}
|
||||
|
||||
packet2 := testDataPacket("client-session", 1, 2, "two")
|
||||
@@ -93,12 +93,12 @@ func TestSOCKSStateInboundReorderQueuesUntilGapFilled(t *testing.T) {
|
||||
|
||||
func TestSOCKSStateInboundGapTimeout(t *testing.T) {
|
||||
socksState := &SOCKSState{
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
socksState.PendingInbound[3] = PendingInboundPacket{
|
||||
socksState.PendingInbound[3] = []PendingInboundPacket{{
|
||||
Packet: testDataPacket("client-session", 1, 3, "late"),
|
||||
QueuedAt: time.Now().Add(-2 * time.Second),
|
||||
}
|
||||
}}
|
||||
|
||||
if !socksState.hasExpiredInboundGapLocked(time.Now(), 500*time.Millisecond) {
|
||||
t.Fatal("expected inbound gap timeout to trigger")
|
||||
@@ -110,7 +110,7 @@ func TestSOCKSStateInboundGapTimeout(t *testing.T) {
|
||||
|
||||
func TestSOCKSStateInboundDataWaitsForConnect(t *testing.T) {
|
||||
socksState := &SOCKSState{
|
||||
PendingInbound: make(map[uint64]PendingInboundPacket),
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
|
||||
packet1 := testDataPacket("client-session", 1, 1, "one")
|
||||
@@ -132,6 +132,40 @@ func TestSOCKSStateInboundDataWaitsForConnect(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSOCKSStateInboundReorderAllowsMultiplePacketTypesPerSequence(t *testing.T) {
|
||||
socksState := &SOCKSState{
|
||||
ConnectAcked: true,
|
||||
PendingInbound: make(map[uint64][]PendingInboundPacket),
|
||||
}
|
||||
|
||||
closeWrite := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSCloseWrite)
|
||||
closeWrite.SOCKSID = 1
|
||||
closeWrite.Sequence = 2
|
||||
|
||||
closeRead := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSCloseRead)
|
||||
closeRead.SOCKSID = 1
|
||||
closeRead.Sequence = 2
|
||||
|
||||
ready, duplicate, overflow := socksState.queueInboundPacketLocked(closeWrite, time.Now(), 8)
|
||||
if duplicate || overflow || len(ready) != 0 {
|
||||
t.Fatalf("expected first close packet to buffer, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
||||
}
|
||||
|
||||
ready, duplicate, overflow = socksState.queueInboundPacketLocked(closeRead, time.Now(), 8)
|
||||
if duplicate || overflow || len(ready) != 0 {
|
||||
t.Fatalf("expected second close packet on same sequence to buffer, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
||||
}
|
||||
|
||||
data := testDataPacket("client-session", 1, 1, "one")
|
||||
ready, duplicate, overflow = socksState.queueInboundPacketLocked(data, time.Now(), 8)
|
||||
if duplicate || overflow || len(ready) != 3 {
|
||||
t.Fatalf("expected data and both close packets to drain, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
||||
}
|
||||
if ready[0].Type != protocol.PacketTypeSOCKSData || ready[1].Type != protocol.PacketTypeSOCKSCloseRead || ready[2].Type != protocol.PacketTypeSOCKSCloseWrite {
|
||||
t.Fatalf("unexpected drain order: %s, %s, %s", ready[0].Type, ready[1].Type, ready[2].Type)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSOCKSStateReleaseClearsQueueState(t *testing.T) {
|
||||
socksState := &SOCKSState{
|
||||
Target: &protocol.Target{Host: "example.com", Port: 443},
|
||||
|
||||
Reference in New Issue
Block a user