mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-18 23:54:37 +03:00
Add in-flight ACK tracking and retry handling
This commit is contained in:
@@ -17,4 +17,6 @@ WORKER_COUNT = 4
|
|||||||
HTTP_REQUEST_TIMEOUT_MS = 15000
|
HTTP_REQUEST_TIMEOUT_MS = 15000
|
||||||
WORKER_POLL_INTERVAL_MS = 200
|
WORKER_POLL_INTERVAL_MS = 200
|
||||||
MAX_QUEUE_BYTES_PER_SOCKS = 1048576
|
MAX_QUEUE_BYTES_PER_SOCKS = 1048576
|
||||||
|
ACK_TIMEOUT_MS = 5000
|
||||||
|
MAX_RETRY_COUNT = 5
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
|||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.reclaimExpiredInFlight()
|
||||||
batch, selected := c.buildNextBatch()
|
batch, selected := c.buildNextBatch()
|
||||||
if len(batch.Packets) == 0 {
|
if len(batch.Packets) == 0 {
|
||||||
time.Sleep(pollInterval)
|
time.Sleep(pollInterval)
|
||||||
@@ -68,6 +69,8 @@ func (w *sendWorker) run(ctx context.Context, c *Client) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.markSelectedInFlight(selected)
|
||||||
|
|
||||||
body, err := protocol.EncryptBatch(batch, c.cfg.AESEncryptionKey)
|
body, err := protocol.EncryptBatch(batch, c.cfg.AESEncryptionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
|
c.log.Errorf("<red>worker=<cyan>%d</cyan> encrypt batch failed: <cyan>%v</cyan></red>", w.id, err)
|
||||||
@@ -164,13 +167,41 @@ func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) requeueSelected(selected []dequeuedPacket) {
|
func (c *Client) requeueSelected(selected []dequeuedPacket) {
|
||||||
|
grouped := make(map[*SOCKSConnection][]string)
|
||||||
|
for _, entry := range selected {
|
||||||
|
grouped[entry.socksConn] = append(grouped[entry.socksConn], entry.item.IdentityKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
for socksConn, identityKeys := range grouped {
|
||||||
|
socksConn.RequeueInFlightByIdentity(identityKeys)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) markSelectedInFlight(selected []dequeuedPacket) {
|
||||||
grouped := make(map[*SOCKSConnection][]*SOCKSOutboundQueueItem)
|
grouped := make(map[*SOCKSConnection][]*SOCKSOutboundQueueItem)
|
||||||
for _, entry := range selected {
|
for _, entry := range selected {
|
||||||
grouped[entry.socksConn] = append(grouped[entry.socksConn], entry.item)
|
grouped[entry.socksConn] = append(grouped[entry.socksConn], entry.item)
|
||||||
}
|
}
|
||||||
|
|
||||||
for socksConn, items := range grouped {
|
for socksConn, items := range grouped {
|
||||||
socksConn.RequeueFront(items)
|
socksConn.MarkInFlight(items)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) reclaimExpiredInFlight() {
|
||||||
|
ackTimeout := time.Duration(c.cfg.AckTimeoutMS) * time.Millisecond
|
||||||
|
for _, socksConn := range c.socksConnections.Snapshot() {
|
||||||
|
requeued, dropped := socksConn.ReclaimExpiredInFlight(ackTimeout, c.cfg.MaxRetryCount)
|
||||||
|
if requeued > 0 || dropped > 0 {
|
||||||
|
c.log.Warnf(
|
||||||
|
"<yellow>socks_id=<cyan>%d</cyan> reclaimed inflight requeued=<cyan>%d</cyan> dropped=<cyan>%d</cyan></yellow>",
|
||||||
|
socksConn.ID, requeued, dropped,
|
||||||
|
)
|
||||||
|
if dropped > 0 {
|
||||||
|
socksConn.ConnectFailure = "max retry exceeded"
|
||||||
|
socksConn.CompleteConnect(fmt.Errorf("max retry exceeded"))
|
||||||
|
_ = socksConn.CloseLocal()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,7 +258,7 @@ func (c *Client) applyResponseBatch(batch protocol.Batch) error {
|
|||||||
|
|
||||||
func (c *Client) applyResponsePacket(packet protocol.Packet) error {
|
func (c *Client) applyResponsePacket(packet protocol.Packet) error {
|
||||||
switch packet.Type {
|
switch packet.Type {
|
||||||
case protocol.PacketTypePing, protocol.PacketTypePong, protocol.PacketTypeSOCKSDataAck:
|
case protocol.PacketTypePing, protocol.PacketTypePong:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,6 +269,7 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error {
|
|||||||
|
|
||||||
switch packet.Type {
|
switch packet.Type {
|
||||||
case protocol.PacketTypeSOCKSConnectAck:
|
case protocol.PacketTypeSOCKSConnectAck:
|
||||||
|
_ = socksConn.AckPacket(packet)
|
||||||
socksConn.ConnectAccepted = true
|
socksConn.ConnectAccepted = true
|
||||||
socksConn.LastActivityAt = time.Now()
|
socksConn.LastActivityAt = time.Now()
|
||||||
socksConn.CompleteConnect(nil)
|
socksConn.CompleteConnect(nil)
|
||||||
@@ -257,16 +289,23 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error {
|
|||||||
if len(packet.Payload) > 0 {
|
if len(packet.Payload) > 0 {
|
||||||
message = string(packet.Payload)
|
message = string(packet.Payload)
|
||||||
}
|
}
|
||||||
|
_ = socksConn.AckPacket(packet)
|
||||||
socksConn.ConnectFailure = message
|
socksConn.ConnectFailure = message
|
||||||
socksConn.CompleteConnect(fmt.Errorf("%s", message))
|
socksConn.CompleteConnect(fmt.Errorf("%s", message))
|
||||||
_ = socksConn.CloseLocal()
|
_ = socksConn.CloseLocal()
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
|
case protocol.PacketTypeSOCKSDataAck:
|
||||||
|
_ = socksConn.AckPacket(packet)
|
||||||
|
socksConn.LastActivityAt = time.Now()
|
||||||
|
return nil
|
||||||
|
|
||||||
case protocol.PacketTypeSOCKSData:
|
case protocol.PacketTypeSOCKSData:
|
||||||
socksConn.LastActivityAt = time.Now()
|
socksConn.LastActivityAt = time.Now()
|
||||||
return socksConn.WriteToLocal(packet.Payload)
|
return socksConn.WriteToLocal(packet.Payload)
|
||||||
|
|
||||||
case protocol.PacketTypeSOCKSCloseRead, protocol.PacketTypeSOCKSCloseWrite, protocol.PacketTypeSOCKSRST:
|
case protocol.PacketTypeSOCKSCloseRead, protocol.PacketTypeSOCKSCloseWrite, protocol.PacketTypeSOCKSRST:
|
||||||
|
_ = socksConn.AckPacket(packet)
|
||||||
socksConn.LastActivityAt = time.Now()
|
socksConn.LastActivityAt = time.Now()
|
||||||
return socksConn.CloseLocal()
|
return socksConn.CloseLocal()
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ type SOCKSConnection struct {
|
|||||||
queueMu sync.Mutex
|
queueMu sync.Mutex
|
||||||
OutboundQueue []*SOCKSOutboundQueueItem
|
OutboundQueue []*SOCKSOutboundQueueItem
|
||||||
QueuedBytes int
|
QueuedBytes int
|
||||||
|
InFlight map[string]*SOCKSOutboundQueueItem
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SOCKSConnection) InitialPayloadHex() string {
|
func (s *SOCKSConnection) InitialPayloadHex() string {
|
||||||
@@ -74,6 +75,7 @@ func (s *SOCKSConnectionStore) New(clientSessionKey string, clientAddress string
|
|||||||
LastActivityAt: now,
|
LastActivityAt: now,
|
||||||
ClientAddress: clientAddress,
|
ClientAddress: clientAddress,
|
||||||
connectResultC: make(chan error, 1),
|
connectResultC: make(chan error, 1),
|
||||||
|
InFlight: make(map[string]*SOCKSOutboundQueueItem),
|
||||||
}
|
}
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
|
|||||||
@@ -39,7 +39,9 @@ type SOCKSOutboundQueueItem struct {
|
|||||||
IdentityKey string
|
IdentityKey string
|
||||||
Packet protocol.Packet
|
Packet protocol.Packet
|
||||||
QueuedAt time.Time
|
QueuedAt time.Time
|
||||||
|
SentAt time.Time
|
||||||
PayloadSize int
|
PayloadSize int
|
||||||
|
RetryCount int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *SOCKSConnection) EnqueuePacket(packet protocol.Packet) error {
|
func (s *SOCKSConnection) EnqueuePacket(packet protocol.Packet) error {
|
||||||
@@ -136,6 +138,127 @@ func (s *SOCKSConnection) RequeueFront(items []*SOCKSOutboundQueueItem) {
|
|||||||
s.OutboundQueue = front
|
s.OutboundQueue = front
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *SOCKSConnection) MarkInFlight(items []*SOCKSOutboundQueueItem) {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.queueMu.Lock()
|
||||||
|
defer s.queueMu.Unlock()
|
||||||
|
|
||||||
|
for _, item := range items {
|
||||||
|
if item == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
item.SentAt = time.Now()
|
||||||
|
s.InFlight[item.IdentityKey] = item
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SOCKSConnection) AckPacket(packet protocol.Packet) bool {
|
||||||
|
identityKey := protocol.PacketIdentityKey(
|
||||||
|
packet.ClientSessionKey,
|
||||||
|
packet.SOCKSID,
|
||||||
|
ackTargetPacketType(packet.Type),
|
||||||
|
packet.Sequence,
|
||||||
|
packet.FragmentID,
|
||||||
|
)
|
||||||
|
|
||||||
|
s.queueMu.Lock()
|
||||||
|
defer s.queueMu.Unlock()
|
||||||
|
if _, ok := s.InFlight[identityKey]; ok {
|
||||||
|
delete(s.InFlight, identityKey)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SOCKSConnection) RequeueInFlightByIdentity(identityKeys []string) {
|
||||||
|
if len(identityKeys) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s.queueMu.Lock()
|
||||||
|
defer s.queueMu.Unlock()
|
||||||
|
|
||||||
|
front := make([]*SOCKSOutboundQueueItem, 0, len(identityKeys)+len(s.OutboundQueue))
|
||||||
|
for _, identityKey := range identityKeys {
|
||||||
|
item, ok := s.InFlight[identityKey]
|
||||||
|
if !ok || item == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
delete(s.InFlight, identityKey)
|
||||||
|
item.SentAt = time.Time{}
|
||||||
|
front = append(front, item)
|
||||||
|
s.QueuedBytes += item.PayloadSize
|
||||||
|
}
|
||||||
|
front = append(front, s.OutboundQueue...)
|
||||||
|
s.OutboundQueue = front
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SOCKSConnection) ReclaimExpiredInFlight(ackTimeout time.Duration, maxRetryCount int) (requeued int, dropped int) {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
s.queueMu.Lock()
|
||||||
|
defer s.queueMu.Unlock()
|
||||||
|
|
||||||
|
if len(s.InFlight) == 0 {
|
||||||
|
return 0, 0
|
||||||
|
}
|
||||||
|
|
||||||
|
front := make([]*SOCKSOutboundQueueItem, 0, len(s.InFlight)+len(s.OutboundQueue))
|
||||||
|
for identityKey, item := range s.InFlight {
|
||||||
|
if item == nil || item.SentAt.IsZero() || now.Sub(item.SentAt) < ackTimeout {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(s.InFlight, identityKey)
|
||||||
|
if item.RetryCount >= maxRetryCount {
|
||||||
|
dropped++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
item.RetryCount++
|
||||||
|
item.SentAt = time.Time{}
|
||||||
|
front = append(front, item)
|
||||||
|
s.QueuedBytes += item.PayloadSize
|
||||||
|
requeued++
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(front) > 0 {
|
||||||
|
front = append(front, s.OutboundQueue...)
|
||||||
|
s.OutboundQueue = front
|
||||||
|
}
|
||||||
|
return requeued, dropped
|
||||||
|
}
|
||||||
|
|
||||||
|
func ackTargetPacketType(packetType protocol.PacketType) protocol.PacketType {
|
||||||
|
switch packetType {
|
||||||
|
case protocol.PacketTypeSOCKSConnectAck,
|
||||||
|
protocol.PacketTypeSOCKSConnectFail,
|
||||||
|
protocol.PacketTypeSOCKSRuleSetDenied,
|
||||||
|
protocol.PacketTypeSOCKSNetworkUnreachable,
|
||||||
|
protocol.PacketTypeSOCKSHostUnreachable,
|
||||||
|
protocol.PacketTypeSOCKSConnectionRefused,
|
||||||
|
protocol.PacketTypeSOCKSTTLExpired,
|
||||||
|
protocol.PacketTypeSOCKSCommandUnsupported,
|
||||||
|
protocol.PacketTypeSOCKSAddressTypeUnsupported,
|
||||||
|
protocol.PacketTypeSOCKSAuthFailed,
|
||||||
|
protocol.PacketTypeSOCKSUpstreamUnavailable:
|
||||||
|
return protocol.PacketTypeSOCKSConnect
|
||||||
|
case protocol.PacketTypeSOCKSDataAck:
|
||||||
|
return protocol.PacketTypeSOCKSData
|
||||||
|
case protocol.PacketTypeSOCKSCloseRead:
|
||||||
|
return protocol.PacketTypeSOCKSCloseRead
|
||||||
|
case protocol.PacketTypeSOCKSCloseWrite:
|
||||||
|
return protocol.PacketTypeSOCKSCloseWrite
|
||||||
|
case protocol.PacketTypeSOCKSRST:
|
||||||
|
return protocol.PacketTypeSOCKSRST
|
||||||
|
default:
|
||||||
|
return packetType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func splitPayloadChunks(payload []byte, maxChunkSize int) [][]byte {
|
func splitPayloadChunks(payload []byte, maxChunkSize int) [][]byte {
|
||||||
if len(payload) == 0 || maxChunkSize <= 0 {
|
if len(payload) == 0 || maxChunkSize <= 0 {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ type Config struct {
|
|||||||
HTTPRequestTimeoutMS int
|
HTTPRequestTimeoutMS int
|
||||||
WorkerPollIntervalMS int
|
WorkerPollIntervalMS int
|
||||||
MaxQueueBytesPerSOCKS int
|
MaxQueueBytesPerSOCKS int
|
||||||
|
AckTimeoutMS int
|
||||||
|
MaxRetryCount int
|
||||||
SessionIdleTimeoutMS int
|
SessionIdleTimeoutMS int
|
||||||
SOCKSIdleTimeoutMS int
|
SOCKSIdleTimeoutMS int
|
||||||
ReadBodyLimitBytes int
|
ReadBodyLimitBytes int
|
||||||
@@ -52,6 +54,8 @@ func Load(path string) (Config, error) {
|
|||||||
HTTPRequestTimeoutMS: 15000,
|
HTTPRequestTimeoutMS: 15000,
|
||||||
WorkerPollIntervalMS: 200,
|
WorkerPollIntervalMS: 200,
|
||||||
MaxQueueBytesPerSOCKS: 1024 * 1024,
|
MaxQueueBytesPerSOCKS: 1024 * 1024,
|
||||||
|
AckTimeoutMS: 5000,
|
||||||
|
MaxRetryCount: 5,
|
||||||
SessionIdleTimeoutMS: 5 * 60 * 1000,
|
SessionIdleTimeoutMS: 5 * 60 * 1000,
|
||||||
SOCKSIdleTimeoutMS: 2 * 60 * 1000,
|
SOCKSIdleTimeoutMS: 2 * 60 * 1000,
|
||||||
ReadBodyLimitBytes: 2 * 1024 * 1024,
|
ReadBodyLimitBytes: 2 * 1024 * 1024,
|
||||||
@@ -153,6 +157,18 @@ func Load(path string) (Config, error) {
|
|||||||
return Config{}, fmt.Errorf("parse MAX_QUEUE_BYTES_PER_SOCKS: %w", err)
|
return Config{}, fmt.Errorf("parse MAX_QUEUE_BYTES_PER_SOCKS: %w", err)
|
||||||
}
|
}
|
||||||
cfg.MaxQueueBytesPerSOCKS = size
|
cfg.MaxQueueBytesPerSOCKS = size
|
||||||
|
case "ACK_TIMEOUT_MS":
|
||||||
|
timeout, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("parse ACK_TIMEOUT_MS: %w", err)
|
||||||
|
}
|
||||||
|
cfg.AckTimeoutMS = timeout
|
||||||
|
case "MAX_RETRY_COUNT":
|
||||||
|
count, err := strconv.Atoi(value)
|
||||||
|
if err != nil {
|
||||||
|
return Config{}, fmt.Errorf("parse MAX_RETRY_COUNT: %w", err)
|
||||||
|
}
|
||||||
|
cfg.MaxRetryCount = count
|
||||||
case "SESSION_IDLE_TIMEOUT_MS":
|
case "SESSION_IDLE_TIMEOUT_MS":
|
||||||
timeout, err := strconv.Atoi(value)
|
timeout, err := strconv.Atoi(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -200,6 +216,12 @@ func (c Config) ValidateClient() error {
|
|||||||
if c.WorkerPollIntervalMS < 1 {
|
if c.WorkerPollIntervalMS < 1 {
|
||||||
return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS)
|
return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS)
|
||||||
}
|
}
|
||||||
|
if c.AckTimeoutMS < 1 {
|
||||||
|
return fmt.Errorf("invalid ACK_TIMEOUT_MS: %d", c.AckTimeoutMS)
|
||||||
|
}
|
||||||
|
if c.MaxRetryCount < 0 {
|
||||||
|
return fmt.Errorf("invalid MAX_RETRY_COUNT: %d", c.MaxRetryCount)
|
||||||
|
}
|
||||||
if c.MaxQueueBytesPerSOCKS < c.MaxChunkSize {
|
if c.MaxQueueBytesPerSOCKS < c.MaxChunkSize {
|
||||||
return fmt.Errorf("MAX_QUEUE_BYTES_PER_SOCKS must be >= MAX_CHUNK_SIZE")
|
return fmt.Errorf("MAX_QUEUE_BYTES_PER_SOCKS must be >= MAX_CHUNK_SIZE")
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user