Improve queue cleanup, batch fairness, and worker CPU efficiency

This commit is contained in:
Amin.MasterkinG
2026-04-20 20:29:23 +03:30
parent c4776c88e1
commit 1c4dd0138b
5 changed files with 117 additions and 9 deletions
+10
View File
@@ -29,8 +29,10 @@ type Client struct {
connMu sync.Mutex
conns map[net.Conn]struct{}
workCh chan struct{}
lastPollUnixMS atomic.Int64
batchCursor atomic.Uint64
}
func New(cfg config.Config, lg *logger.Logger) *Client {
@@ -43,6 +45,7 @@ func New(cfg config.Config, lg *logger.Logger) *Client {
socksConnections: NewSOCKSConnectionStore(),
chunkPolicy: newChunkPolicy(cfg),
conns: make(map[net.Conn]struct{}),
workCh: make(chan struct{}, 1),
}
}
@@ -118,6 +121,13 @@ func (c *Client) closeAllConns() {
}
}
func (c *Client) signalSendWork() {
select {
case c.workCh <- struct{}{}:
default:
}
}
func generateClientSessionKey() string {
now := time.Now().UTC().Format("20060102T150405.000000000Z")
random := make([]byte, 16)
+34 -5
View File
@@ -58,14 +58,14 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
c.reclaimExpiredInFlight()
batch, selected := c.buildNextBatch()
if len(batch.Packets) == 0 {
time.Sleep(pollInterval)
c.waitForSendWork(ctx, pollInterval)
continue
}
if err := batch.Validate(); err != nil {
c.log.Errorf("<red>worker=<cyan>%d</cyan> invalid batch: <cyan>%v</cyan></red>", w.id, err)
c.requeueSelected(selected)
time.Sleep(pollInterval)
c.waitForSendWork(ctx, pollInterval)
continue
}
@@ -75,14 +75,14 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
if err != nil {
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
c.requeueSelected(selected)
time.Sleep(pollInterval)
c.waitForSendWork(ctx, pollInterval)
continue
}
if err := w.postBatch(ctx, c, batch, body); err != nil {
c.log.Warnf("<yellow>worker=<cyan>%d</cyan> send failed for batch=<cyan>%s</cyan>: <cyan>%v</cyan></yellow>", w.id, batch.BatchID, err)
c.requeueSelected(selected)
time.Sleep(pollInterval)
c.waitForSendWork(ctx, pollInterval)
continue
}
@@ -93,8 +93,28 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
}
}
func (c *Client) waitForSendWork(ctx context.Context, interval time.Duration) {
timer := time.NewTimer(interval)
defer timer.Stop()
select {
case <-ctx.Done():
case <-c.workCh:
case <-timer.C:
}
}
func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
connections := c.socksConnections.Snapshot()
if len(connections) == 0 {
return protocol.Batch{}, nil
}
start := 0
if len(connections) > 1 {
start = int(c.batchCursor.Add(1)-1) % len(connections)
}
selected := make([]dequeuedPacket, 0, c.cfg.MaxPacketsPerBatch)
packets := make([]protocol.Packet, 0, c.cfg.MaxPacketsPerBatch)
totalBytes := 0
@@ -102,11 +122,13 @@ func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
for len(selected) < c.cfg.MaxPacketsPerBatch {
progress := false
for _, socksConn := range connections {
for offset := 0; offset < len(connections); offset++ {
if len(selected) >= c.cfg.MaxPacketsPerBatch {
break
}
socksConn := connections[(start+offset)%len(connections)]
item := socksConn.DequeuePacket()
if item == nil {
continue
@@ -175,6 +197,9 @@ func (c *Client) requeueSelected(selected []dequeuedPacket) {
for socksConn, identityKeys := range grouped {
socksConn.RequeueInFlightByIdentity(identityKeys)
}
if len(grouped) > 0 {
c.signalSendWork()
}
}
func (c *Client) markSelectedInFlight(selected []dequeuedPacket) {
@@ -196,9 +221,13 @@ func (c *Client) reclaimExpiredInFlight() {
"<yellow>socks_id=<cyan>%d</cyan> reclaimed inflight requeued=<cyan>%d</cyan> dropped=<cyan>%d</cyan></yellow>",
socksConn.ID, requeued, dropped,
)
if requeued > 0 {
c.signalSendWork()
}
if dropped > 0 {
socksConn.ConnectFailure = "max retry exceeded"
socksConn.CompleteConnect(fmt.Errorf("max retry exceeded"))
socksConn.ResetTransportState()
_ = socksConn.CloseLocal()
}
}
+21
View File
@@ -185,6 +185,20 @@ func (s *SOCKSConnection) WaitUntilClosed(ctx context.Context) {
}
}
func (s *SOCKSConnection) ResetTransportState() {
s.queueMu.Lock()
for i := range s.OutboundQueue {
s.OutboundQueue[i] = nil
}
s.OutboundQueue = nil
s.QueuedBytes = 0
clear(s.InFlight)
s.queueMu.Unlock()
s.InitialPayload = nil
s.BufferedBytes = 0
}
func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection {
s.mu.RLock()
defer s.mu.RUnlock()
@@ -193,8 +207,14 @@ func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection {
func (s *SOCKSConnectionStore) Delete(id uint64) {
s.mu.Lock()
item := s.items[id]
delete(s.items, id)
s.mu.Unlock()
if item != nil {
item.ResetTransportState()
_ = item.CloseLocal()
}
}
func (s *SOCKSConnectionStore) CloseAll() {
@@ -207,6 +227,7 @@ func (s *SOCKSConnectionStore) CloseAll() {
s.mu.Unlock()
for _, item := range items {
item.ResetTransportState()
_ = item.CloseLocal()
}
}
+13 -2
View File
@@ -101,6 +101,7 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn, socksConn *SOC
if err := socksConn.EnqueuePacket(socksConn.BuildSOCKSConnectPacket()); err != nil {
return err
}
c.signalSendWork()
if err := socksConn.WaitForConnect(ctx, 30*time.Second); err != nil {
_ = writeSocksReply(conn, socksReplyGeneralFailure)
@@ -279,6 +280,9 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
if enqueueErr != nil {
return enqueueErr
}
if enqueued > 0 {
c.signalSendWork()
}
} else if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
if errors.Is(err, io.EOF) {
return nil
@@ -312,19 +316,26 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
if enqueueErr != nil {
return enqueueErr
}
if enqueued > 0 {
c.signalSendWork()
}
}
if err != nil {
if errors.Is(err, io.EOF) {
socksConn.MarkLocalReadEOF()
_ = socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket())
if enqueueErr := socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket()); enqueueErr == nil {
c.signalSendWork()
}
socksConn.WaitUntilClosed(ctx)
return nil
}
if ne, ok := err.(net.Error); ok && ne.Timeout() {
continue
}
_ = socksConn.EnqueuePacket(socksConn.BuildSOCKSRSTPacket())
if enqueueErr := socksConn.EnqueuePacket(socksConn.BuildSOCKSRSTPacket()); enqueueErr == nil {
c.signalSendWork()
}
return err
}
}