mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-18 06:44:35 +03:00
Improve queue cleanup, batch fairness, and worker CPU efficiency
This commit is contained in:
@@ -29,8 +29,10 @@ type Client struct {
|
||||
|
||||
connMu sync.Mutex
|
||||
conns map[net.Conn]struct{}
|
||||
workCh chan struct{}
|
||||
|
||||
lastPollUnixMS atomic.Int64
|
||||
batchCursor atomic.Uint64
|
||||
}
|
||||
|
||||
func New(cfg config.Config, lg *logger.Logger) *Client {
|
||||
@@ -43,6 +45,7 @@ func New(cfg config.Config, lg *logger.Logger) *Client {
|
||||
socksConnections: NewSOCKSConnectionStore(),
|
||||
chunkPolicy: newChunkPolicy(cfg),
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
workCh: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,6 +121,13 @@ func (c *Client) closeAllConns() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) signalSendWork() {
|
||||
select {
|
||||
case c.workCh <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func generateClientSessionKey() string {
|
||||
now := time.Now().UTC().Format("20060102T150405.000000000Z")
|
||||
random := make([]byte, 16)
|
||||
|
||||
@@ -58,14 +58,14 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
||||
c.reclaimExpiredInFlight()
|
||||
batch, selected := c.buildNextBatch()
|
||||
if len(batch.Packets) == 0 {
|
||||
time.Sleep(pollInterval)
|
||||
c.waitForSendWork(ctx, pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := batch.Validate(); err != nil {
|
||||
c.log.Errorf("<red>worker=<cyan>%d</cyan> invalid batch: <cyan>%v</cyan></red>", w.id, err)
|
||||
c.requeueSelected(selected)
|
||||
time.Sleep(pollInterval)
|
||||
c.waitForSendWork(ctx, pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -75,14 +75,14 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
||||
if err != nil {
|
||||
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
|
||||
c.requeueSelected(selected)
|
||||
time.Sleep(pollInterval)
|
||||
c.waitForSendWork(ctx, pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := w.postBatch(ctx, c, batch, body); err != nil {
|
||||
c.log.Warnf("<yellow>worker=<cyan>%d</cyan> send failed for batch=<cyan>%s</cyan>: <cyan>%v</cyan></yellow>", w.id, batch.BatchID, err)
|
||||
c.requeueSelected(selected)
|
||||
time.Sleep(pollInterval)
|
||||
c.waitForSendWork(ctx, pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -93,8 +93,28 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) waitForSendWork(ctx context.Context, interval time.Duration) {
|
||||
timer := time.NewTimer(interval)
|
||||
defer timer.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-c.workCh:
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
|
||||
connections := c.socksConnections.Snapshot()
|
||||
if len(connections) == 0 {
|
||||
return protocol.Batch{}, nil
|
||||
}
|
||||
|
||||
start := 0
|
||||
if len(connections) > 1 {
|
||||
start = int(c.batchCursor.Add(1)-1) % len(connections)
|
||||
}
|
||||
|
||||
selected := make([]dequeuedPacket, 0, c.cfg.MaxPacketsPerBatch)
|
||||
packets := make([]protocol.Packet, 0, c.cfg.MaxPacketsPerBatch)
|
||||
totalBytes := 0
|
||||
@@ -102,11 +122,13 @@ func (c *Client) buildNextBatch() (protocol.Batch, []dequeuedPacket) {
|
||||
for len(selected) < c.cfg.MaxPacketsPerBatch {
|
||||
progress := false
|
||||
|
||||
for _, socksConn := range connections {
|
||||
for offset := 0; offset < len(connections); offset++ {
|
||||
if len(selected) >= c.cfg.MaxPacketsPerBatch {
|
||||
break
|
||||
}
|
||||
|
||||
socksConn := connections[(start+offset)%len(connections)]
|
||||
|
||||
item := socksConn.DequeuePacket()
|
||||
if item == nil {
|
||||
continue
|
||||
@@ -175,6 +197,9 @@ func (c *Client) requeueSelected(selected []dequeuedPacket) {
|
||||
for socksConn, identityKeys := range grouped {
|
||||
socksConn.RequeueInFlightByIdentity(identityKeys)
|
||||
}
|
||||
if len(grouped) > 0 {
|
||||
c.signalSendWork()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) markSelectedInFlight(selected []dequeuedPacket) {
|
||||
@@ -196,9 +221,13 @@ func (c *Client) reclaimExpiredInFlight() {
|
||||
"<yellow>socks_id=<cyan>%d</cyan> reclaimed inflight requeued=<cyan>%d</cyan> dropped=<cyan>%d</cyan></yellow>",
|
||||
socksConn.ID, requeued, dropped,
|
||||
)
|
||||
if requeued > 0 {
|
||||
c.signalSendWork()
|
||||
}
|
||||
if dropped > 0 {
|
||||
socksConn.ConnectFailure = "max retry exceeded"
|
||||
socksConn.CompleteConnect(fmt.Errorf("max retry exceeded"))
|
||||
socksConn.ResetTransportState()
|
||||
_ = socksConn.CloseLocal()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +185,20 @@ func (s *SOCKSConnection) WaitUntilClosed(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) ResetTransportState() {
|
||||
s.queueMu.Lock()
|
||||
for i := range s.OutboundQueue {
|
||||
s.OutboundQueue[i] = nil
|
||||
}
|
||||
s.OutboundQueue = nil
|
||||
s.QueuedBytes = 0
|
||||
clear(s.InFlight)
|
||||
s.queueMu.Unlock()
|
||||
|
||||
s.InitialPayload = nil
|
||||
s.BufferedBytes = 0
|
||||
}
|
||||
|
||||
func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
@@ -193,8 +207,14 @@ func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection {
|
||||
|
||||
func (s *SOCKSConnectionStore) Delete(id uint64) {
|
||||
s.mu.Lock()
|
||||
item := s.items[id]
|
||||
delete(s.items, id)
|
||||
s.mu.Unlock()
|
||||
|
||||
if item != nil {
|
||||
item.ResetTransportState()
|
||||
_ = item.CloseLocal()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SOCKSConnectionStore) CloseAll() {
|
||||
@@ -207,6 +227,7 @@ func (s *SOCKSConnectionStore) CloseAll() {
|
||||
s.mu.Unlock()
|
||||
|
||||
for _, item := range items {
|
||||
item.ResetTransportState()
|
||||
_ = item.CloseLocal()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,6 +101,7 @@ func (c *Client) handleSOCKS5(ctx context.Context, conn net.Conn, socksConn *SOC
|
||||
if err := socksConn.EnqueuePacket(socksConn.BuildSOCKSConnectPacket()); err != nil {
|
||||
return err
|
||||
}
|
||||
c.signalSendWork()
|
||||
|
||||
if err := socksConn.WaitForConnect(ctx, 30*time.Second); err != nil {
|
||||
_ = writeSocksReply(conn, socksReplyGeneralFailure)
|
||||
@@ -279,6 +280,9 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
|
||||
if enqueueErr != nil {
|
||||
return enqueueErr
|
||||
}
|
||||
if enqueued > 0 {
|
||||
c.signalSendWork()
|
||||
}
|
||||
} else if ne, ok := err.(net.Error); !ok || !ne.Timeout() {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return nil
|
||||
@@ -312,19 +316,26 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
|
||||
if enqueueErr != nil {
|
||||
return enqueueErr
|
||||
}
|
||||
if enqueued > 0 {
|
||||
c.signalSendWork()
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
socksConn.MarkLocalReadEOF()
|
||||
_ = socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket())
|
||||
if enqueueErr := socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket()); enqueueErr == nil {
|
||||
c.signalSendWork()
|
||||
}
|
||||
socksConn.WaitUntilClosed(ctx)
|
||||
return nil
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
continue
|
||||
}
|
||||
_ = socksConn.EnqueuePacket(socksConn.BuildSOCKSRSTPacket())
|
||||
if enqueueErr := socksConn.EnqueuePacket(socksConn.BuildSOCKSRSTPacket()); enqueueErr == nil {
|
||||
c.signalSendWork()
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user