Add mux-aware batching limits and burst concurrency control

This commit is contained in:
Amin.MasterkinG
2026-04-21 10:26:42 +03:30
parent 136ddef09a
commit 236ae711c3
5 changed files with 349 additions and 91 deletions
+1
View File
@@ -33,6 +33,7 @@ type Client struct {
workCh chan struct{}
lastPollUnixMS atomic.Int64
activeBatches atomic.Int64
batchCursor atomic.Uint64
}
+116 -15
View File
@@ -12,6 +12,7 @@ import (
"fmt"
"io"
"net/http"
"sort"
"sync"
"time"
@@ -46,8 +47,6 @@ func (c *Client) startSendWorkers(ctx context.Context, wg *sync.WaitGroup) {
}
func (w *sendWorker) run(ctx context.Context, c *Client) {
pollInterval := time.Duration(c.cfg.WorkerPollIntervalMS) * time.Millisecond
for {
select {
case <-ctx.Done():
@@ -57,16 +56,27 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
c.reclaimExpiredInFlight()
c.reclaimExpiredReorder()
batch, selected := c.buildNextBatch()
connections := c.socksConnections.Snapshot()
totalQueuedBytes := queuedBytesAcross(connections)
waitInterval := c.effectiveWaitInterval(totalQueuedBytes)
if !c.tryAcquireBatchSlot(totalQueuedBytes) {
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
continue
}
batch, selected := c.buildNextBatch(connections, totalQueuedBytes)
if len(batch.Packets) == 0 {
c.waitForSendWork(ctx, c.jitterDuration(pollInterval))
c.releaseBatchSlot()
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
continue
}
if err := batch.Validate(); err != nil {
c.log.Errorf("<red>worker=<cyan>%d</cyan> invalid batch: <cyan>%v</cyan></red>", w.id, err)
c.requeueSelected(selected)
c.waitForSendWork(ctx, c.jitterDuration(pollInterval))
c.releaseBatchSlot()
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
continue
}
@@ -76,16 +86,19 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
if err != nil {
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
c.requeueSelected(selected)
c.waitForSendWork(ctx, c.jitterDuration(pollInterval))
c.releaseBatchSlot()
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
continue
}
if err := w.postBatch(ctx, c, batch, body); err != nil {
c.log.Warnf("<yellow>worker=<cyan>%d</cyan> send failed for batch=<cyan>%s</cyan>: <cyan>%v</cyan></yellow>", w.id, batch.BatchID, err)
c.requeueSelected(selected)
c.waitForSendWork(ctx, c.jitterDuration(pollInterval))
c.releaseBatchSlot()
c.waitForSendWork(ctx, c.jitterDuration(waitInterval))
continue
}
c.releaseBatchSlot()
c.log.Debugf(
"<green>worker=<cyan>%d</cyan> sent batch=<cyan>%s</cyan> packets=<cyan>%d</cyan> bytes=<cyan>%d</cyan></green>",
@@ -105,20 +118,30 @@ func (c *Client) waitForSendWork(ctx context.Context, interval time.Duration) {
}
}
func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
connections := c.socksConnections.Snapshot()
func (c *Client) buildNextBatch(connections []*SOCKSConnection, totalQueuedBytes int) (protocol.Batch, []dequeuedPacket) {
if len(connections) == 0 {
return protocol.Batch{}, nil
}
sort.Slice(connections, func(i, j int) bool {
return connections[i].ID < connections[j].ID
})
start := 0
if len(connections) > 1 {
start = int(c.batchCursor.Add(1)-1) % len(connections)
rotationEvery := c.cfg.MuxRotateEveryBatches
if rotationEvery < 1 {
rotationEvery = 1
}
turn := c.batchCursor.Add(1) - 1
start = int((turn / uint64(rotationEvery)) % uint64(len(connections)))
}
maxPackets, maxBatchBytes := c.effectiveBatchLimits()
maxPackets, maxBatchBytes := c.effectiveBatchLimits(totalQueuedBytes)
maxPerSOCKS := c.cfg.MaxPacketsPerSOCKSPerBatch
selected := make([]dequeuedPacket, 0, maxPackets)
packets := make([]protocol.Packet, 0, maxPackets)
selectedPerSOCKS := make(map[uint64]int, len(connections))
totalBytes := 0
for len(selected) < maxPackets {
@@ -130,6 +153,9 @@ func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
}
socksConn := connections[(start+offset)%len(connections)]
if selectedPerSOCKS[socksConn.ID] >= maxPerSOCKS {
continue
}
item := socksConn.DequeuePacket()
if item == nil {
@@ -147,6 +173,7 @@ func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
item: item,
})
packets = append(packets, item.Packet)
selectedPerSOCKS[socksConn.ID]++
totalBytes += packetBytes
progress = true
}
@@ -157,7 +184,7 @@ func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
}
if len(packets) == 0 {
if pingBatch, ok := c.buildPollBatch(connections); ok {
if pingBatch, ok := c.buildPollBatch(connections, totalQueuedBytes); ok {
return pingBatch, nil
}
return protocol.Batch{}, nil
@@ -167,7 +194,7 @@ func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
return batch, selected
}
func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch, bool) {
func (c *Client) buildPollBatch(connections []*SOCKSConnection, totalQueuedBytes int) (protocol.Batch, bool) {
if len(connections) == 0 {
return protocol.Batch{}, false
}
@@ -175,7 +202,7 @@ func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch,
now := time.Now()
nowUnixMS := now.UnixMilli()
lastUnixMS := c.lastPollUnixMS.Load()
minInterval := c.jitterDuration(time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond)
minInterval := c.jitterDuration(c.effectiveIdlePollInterval(totalQueuedBytes))
if lastUnixMS > 0 && nowUnixMS-lastUnixMS < minInterval.Milliseconds() {
return protocol.Batch{}, false
}
@@ -190,9 +217,17 @@ func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch,
return batch, true
}
func (c *Client) effectiveBatchLimits() (int, int) {
func (c *Client) effectiveBatchLimits(totalQueuedBytes int) (int, int) {
maxPackets := c.cfg.MaxPacketsPerBatch
maxBatchBytes := c.cfg.MaxBatchBytes
if totalQueuedBytes < c.cfg.MuxBurstThresholdBytes {
if reducedPackets := maxPackets / 2; reducedPackets >= 1 {
maxPackets = reducedPackets
}
if reducedBytes := maxBatchBytes / 2; reducedBytes >= c.cfg.MaxChunkSize {
maxBatchBytes = reducedBytes
}
}
if !c.cfg.HTTPBatchRandomize {
return maxPackets, maxBatchBytes
}
@@ -214,6 +249,72 @@ func (c *Client) effectiveBatchLimits() (int, int) {
return maxPackets, maxBatchBytes
}
func (c *Client) effectiveWaitInterval(totalQueuedBytes int) time.Duration {
interval := time.Duration(c.cfg.WorkerPollIntervalMS) * time.Millisecond
if totalQueuedBytes >= c.cfg.MuxBurstThresholdBytes {
if burst := interval / 2; burst >= 25*time.Millisecond {
return burst
}
return 25 * time.Millisecond
}
return interval
}
func (c *Client) effectiveIdlePollInterval(totalQueuedBytes int) time.Duration {
interval := time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond
if totalQueuedBytes >= c.cfg.MuxBurstThresholdBytes {
if burst := interval / 2; burst >= time.Duration(c.cfg.WorkerPollIntervalMS)*time.Millisecond {
return burst
}
}
return interval
}
func (c *Client) effectiveConcurrentBatches(totalQueuedBytes int) int {
if totalQueuedBytes >= c.cfg.MuxBurstThresholdBytes {
return c.cfg.MaxConcurrentBatches
}
return 1
}
func (c *Client) tryAcquireBatchSlot(totalQueuedBytes int) bool {
limit := c.effectiveConcurrentBatches(totalQueuedBytes)
if limit < 1 {
limit = 1
}
for {
current := c.activeBatches.Load()
if int(current) >= limit {
return false
}
if c.activeBatches.CompareAndSwap(current, current+1) {
return true
}
}
}
func (c *Client) releaseBatchSlot() {
for {
current := c.activeBatches.Load()
if current <= 0 {
return
}
if c.activeBatches.CompareAndSwap(current, current-1) {
return
}
}
}
func queuedBytesAcross(connections []*SOCKSConnection) int {
total := 0
for _, socksConn := range connections {
_, queuedBytes := socksConn.QueueSnapshot()
total += queuedBytes
}
return total
}
func (c *Client) jitterDuration(base time.Duration) time.Duration {
if base <= 0 || c.cfg.HTTPTimingJitterMS <= 0 {
return base
+77 -7
View File
@@ -151,12 +151,16 @@ func TestSOCKSConnectionInboundDataWaitsForConnectAck(t *testing.T) {
func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) {
cfg := config.Config{
MaxChunkSize: 1024,
MaxPacketsPerBatch: 1,
MaxBatchBytes: 4096,
WorkerCount: 1,
MaxQueueBytesPerSOCKS: 4096,
HTTPBatchRandomize: false,
MaxChunkSize: 1024,
MaxPacketsPerBatch: 1,
MaxBatchBytes: 4096,
WorkerCount: 1,
MaxConcurrentBatches: 1,
MaxPacketsPerSOCKSPerBatch: 1,
MuxRotateEveryBatches: 1,
MuxBurstThresholdBytes: 1024,
MaxQueueBytesPerSOCKS: 4096,
HTTPBatchRandomize: false,
}
client := New(cfg, nil)
@@ -174,7 +178,8 @@ func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) {
seen := make(map[uint64]bool)
for i := 0; i < 3; i++ {
batch, selected := client.buildNextBatch()
connections := client.socksConnections.Snapshot()
batch, selected := client.buildNextBatch(connections, queuedBytesAcross(connections))
if len(batch.Packets) != 1 || len(selected) != 1 {
t.Fatalf("iteration %d: expected one selected packet, got packets=%d selected=%d", i, len(batch.Packets), len(selected))
}
@@ -188,3 +193,68 @@ func TestBuildNextBatchRotatesAcrossConnections(t *testing.T) {
t.Fatalf("expected all 3 socks connections to be selected once, got %d unique selections", len(seen))
}
}
func TestBuildNextBatchHonorsPerSOCKSPacketLimit(t *testing.T) {
cfg := config.Config{
MaxChunkSize: 1024,
MaxPacketsPerBatch: 4,
MaxBatchBytes: 4096,
WorkerCount: 2,
MaxConcurrentBatches: 2,
MaxPacketsPerSOCKSPerBatch: 1,
MuxRotateEveryBatches: 1,
MuxBurstThresholdBytes: 1024,
MaxQueueBytesPerSOCKS: 4096,
HTTPBatchRandomize: false,
}
client := New(cfg, nil)
client.chunkPolicy = newChunkPolicy(cfg)
conn1 := client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1001", client.chunkPolicy)
conn2 := client.socksConnections.New(client.clientSessionKey, "127.0.0.1:1002", client.chunkPolicy)
for i := 0; i < 3; i++ {
if err := conn1.EnqueuePacket(conn1.BuildSOCKSDataPacket([]byte("a"), false)); err != nil {
t.Fatalf("enqueue conn1 packet %d: %v", i, err)
}
}
if err := conn2.EnqueuePacket(conn2.BuildSOCKSDataPacket([]byte("b"), false)); err != nil {
t.Fatalf("enqueue conn2 packet: %v", err)
}
connections := client.socksConnections.Snapshot()
batch, selected := client.buildNextBatch(connections, queuedBytesAcross(connections))
if len(batch.Packets) != 2 || len(selected) != 2 {
t.Fatalf("expected 2 selected packets, got packets=%d selected=%d", len(batch.Packets), len(selected))
}
counts := map[uint64]int{}
for _, packet := range batch.Packets {
counts[packet.SOCKSID]++
}
if counts[conn1.ID] != 1 {
t.Fatalf("expected conn1 to contribute exactly 1 packet, got %d", counts[conn1.ID])
}
if counts[conn2.ID] != 1 {
t.Fatalf("expected conn2 to contribute exactly 1 packet, got %d", counts[conn2.ID])
}
}
func TestEffectiveConcurrentBatchesUsesBurstThreshold(t *testing.T) {
cfg := config.Config{
WorkerCount: 4,
MaxConcurrentBatches: 3,
MaxPacketsPerSOCKSPerBatch: 2,
MuxRotateEveryBatches: 1,
MuxBurstThresholdBytes: 4096,
}
client := New(cfg, nil)
if got := client.effectiveConcurrentBatches(1024); got != 1 {
t.Fatalf("expected low-load concurrency of 1, got %d", got)
}
if got := client.effectiveConcurrentBatches(4096); got != 3 {
t.Fatalf("expected burst concurrency of 3, got %d", got)
}
}
+120 -69
View File
@@ -16,79 +16,87 @@ import (
)
type Config struct {
AESEncryptionKey string
RelayURL string
HTTPUserAgentsFile string
HTTPHeaderProfile string
HTTPRandomizeHeaders bool
HTTPPaddingHeader string
HTTPPaddingMinBytes int
HTTPPaddingMaxBytes int
HTTPReferer string
HTTPAcceptLanguage string
HTTPTimingJitterMS int
HTTPBatchRandomize bool
HTTPBatchPacketsJitter int
HTTPBatchBytesJitter int
ServerHost string
ServerPort int
SOCKSHost string
SOCKSPort int
SOCKSAuth bool
SOCKSUsername string
SOCKSPassword string
LogLevel string
MaxChunkSize int
MaxPacketsPerBatch int
MaxBatchBytes int
WorkerCount int
HTTPRequestTimeoutMS int
WorkerPollIntervalMS int
IdlePollIntervalMS int
MaxQueueBytesPerSOCKS int
AckTimeoutMS int
MaxRetryCount int
ReorderTimeoutMS int
MaxReorderBufferPackets int
SessionIdleTimeoutMS int
SOCKSIdleTimeoutMS int
ReadBodyLimitBytes int
MaxServerQueueBytes int
AESEncryptionKey string
RelayURL string
HTTPUserAgentsFile string
HTTPHeaderProfile string
HTTPRandomizeHeaders bool
HTTPPaddingHeader string
HTTPPaddingMinBytes int
HTTPPaddingMaxBytes int
HTTPReferer string
HTTPAcceptLanguage string
HTTPTimingJitterMS int
HTTPBatchRandomize bool
HTTPBatchPacketsJitter int
HTTPBatchBytesJitter int
ServerHost string
ServerPort int
SOCKSHost string
SOCKSPort int
SOCKSAuth bool
SOCKSUsername string
SOCKSPassword string
LogLevel string
MaxChunkSize int
MaxPacketsPerBatch int
MaxBatchBytes int
WorkerCount int
MaxConcurrentBatches int
MaxPacketsPerSOCKSPerBatch int
MuxRotateEveryBatches int
MuxBurstThresholdBytes int
HTTPRequestTimeoutMS int
WorkerPollIntervalMS int
IdlePollIntervalMS int
MaxQueueBytesPerSOCKS int
AckTimeoutMS int
MaxRetryCount int
ReorderTimeoutMS int
MaxReorderBufferPackets int
SessionIdleTimeoutMS int
SOCKSIdleTimeoutMS int
ReadBodyLimitBytes int
MaxServerQueueBytes int
}
func Load(path string) (Config, error) {
cfg := Config{
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
HTTPUserAgentsFile: "user-agents.txt",
HTTPHeaderProfile: "browser",
HTTPRandomizeHeaders: true,
HTTPPaddingHeader: "X-Padding",
HTTPPaddingMinBytes: 16,
HTTPPaddingMaxBytes: 48,
HTTPTimingJitterMS: 50,
HTTPBatchRandomize: true,
HTTPBatchPacketsJitter: 4,
HTTPBatchBytesJitter: 32768,
ServerHost: "127.0.0.1",
ServerPort: 28080,
LogLevel: "INFO",
MaxChunkSize: 16 * 1024,
MaxPacketsPerBatch: 32,
MaxBatchBytes: 256 * 1024,
WorkerCount: 4,
HTTPRequestTimeoutMS: 15000,
WorkerPollIntervalMS: 200,
IdlePollIntervalMS: 1000,
MaxQueueBytesPerSOCKS: 1024 * 1024,
AckTimeoutMS: 5000,
MaxRetryCount: 5,
ReorderTimeoutMS: 5000,
MaxReorderBufferPackets: 128,
SessionIdleTimeoutMS: 5 * 60 * 1000,
SOCKSIdleTimeoutMS: 2 * 60 * 1000,
ReadBodyLimitBytes: 2 * 1024 * 1024,
MaxServerQueueBytes: 2 * 1024 * 1024,
SOCKSHost: "127.0.0.1",
SOCKSPort: 1080,
HTTPUserAgentsFile: "user-agents.txt",
HTTPHeaderProfile: "browser",
HTTPRandomizeHeaders: true,
HTTPPaddingHeader: "X-Padding",
HTTPPaddingMinBytes: 16,
HTTPPaddingMaxBytes: 48,
HTTPTimingJitterMS: 50,
HTTPBatchRandomize: true,
HTTPBatchPacketsJitter: 4,
HTTPBatchBytesJitter: 32768,
ServerHost: "127.0.0.1",
ServerPort: 28080,
LogLevel: "INFO",
MaxChunkSize: 16 * 1024,
MaxPacketsPerBatch: 32,
MaxBatchBytes: 256 * 1024,
WorkerCount: 4,
MaxConcurrentBatches: 4,
MaxPacketsPerSOCKSPerBatch: 2,
MuxRotateEveryBatches: 1,
MuxBurstThresholdBytes: 128 * 1024,
HTTPRequestTimeoutMS: 15000,
WorkerPollIntervalMS: 200,
IdlePollIntervalMS: 1000,
MaxQueueBytesPerSOCKS: 1024 * 1024,
AckTimeoutMS: 5000,
MaxRetryCount: 5,
ReorderTimeoutMS: 5000,
MaxReorderBufferPackets: 128,
SessionIdleTimeoutMS: 5 * 60 * 1000,
SOCKSIdleTimeoutMS: 2 * 60 * 1000,
ReadBodyLimitBytes: 2 * 1024 * 1024,
MaxServerQueueBytes: 2 * 1024 * 1024,
}
file, err := os.Open(path)
@@ -235,6 +243,34 @@ func Load(path string) (Config, error) {
}
cfg.WorkerCount = count
case "MAX_CONCURRENT_BATCHES":
count, err := strconv.Atoi(value)
if err != nil {
return Config{}, fmt.Errorf("parse MAX_CONCURRENT_BATCHES: %w", err)
}
cfg.MaxConcurrentBatches = count
case "MAX_PACKETS_PER_SOCKS_PER_BATCH":
count, err := strconv.Atoi(value)
if err != nil {
return Config{}, fmt.Errorf("parse MAX_PACKETS_PER_SOCKS_PER_BATCH: %w", err)
}
cfg.MaxPacketsPerSOCKSPerBatch = count
case "MUX_ROTATE_EVERY_BATCHES":
count, err := strconv.Atoi(value)
if err != nil {
return Config{}, fmt.Errorf("parse MUX_ROTATE_EVERY_BATCHES: %w", err)
}
cfg.MuxRotateEveryBatches = count
case "MUX_BURST_THRESHOLD_BYTES":
size, err := strconv.Atoi(value)
if err != nil {
return Config{}, fmt.Errorf("parse MUX_BURST_THRESHOLD_BYTES: %w", err)
}
cfg.MuxBurstThresholdBytes = size
case "HTTP_REQUEST_TIMEOUT_MS":
timeout, err := strconv.Atoi(value)
if err != nil {
@@ -348,6 +384,21 @@ func (c Config) ValidateClient() error {
if c.HTTPRequestTimeoutMS < 1 {
return fmt.Errorf("invalid HTTP_REQUEST_TIMEOUT_MS: %d", c.HTTPRequestTimeoutMS)
}
if c.MaxConcurrentBatches < 1 {
return fmt.Errorf("invalid MAX_CONCURRENT_BATCHES: %d", c.MaxConcurrentBatches)
}
if c.MaxConcurrentBatches > c.WorkerCount {
return fmt.Errorf("MAX_CONCURRENT_BATCHES must be <= WORKER_COUNT")
}
if c.MaxPacketsPerSOCKSPerBatch < 1 {
return fmt.Errorf("invalid MAX_PACKETS_PER_SOCKS_PER_BATCH: %d", c.MaxPacketsPerSOCKSPerBatch)
}
if c.MuxRotateEveryBatches < 1 {
return fmt.Errorf("invalid MUX_ROTATE_EVERY_BATCHES: %d", c.MuxRotateEveryBatches)
}
if c.MuxBurstThresholdBytes < c.MaxChunkSize {
return fmt.Errorf("MUX_BURST_THRESHOLD_BYTES must be >= MAX_CHUNK_SIZE")
}
if c.WorkerPollIntervalMS < 1 {
return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS)