mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-17 21:24:37 +03:00
Improve half-close handling and queue backpressure
This commit is contained in:
@@ -16,6 +16,7 @@ MAX_BATCH_BYTES = 262144
|
||||
WORKER_COUNT = 4
|
||||
HTTP_REQUEST_TIMEOUT_MS = 15000
|
||||
WORKER_POLL_INTERVAL_MS = 200
|
||||
IDLE_POLL_INTERVAL_MS = 1000
|
||||
MAX_QUEUE_BYTES_PER_SOCKS = 1048576
|
||||
ACK_TIMEOUT_MS = 5000
|
||||
MAX_RETRY_COUNT = 5
|
||||
|
||||
@@ -151,7 +151,7 @@ func (c *Client) buildPollBatch(connections []*SOCKSConnection) (protocol.Batch,
|
||||
now := time.Now()
|
||||
nowUnixMS := now.UnixMilli()
|
||||
lastUnixMS := c.lastPollUnixMS.Load()
|
||||
minInterval := time.Duration(c.cfg.WorkerPollIntervalMS) * time.Millisecond
|
||||
minInterval := time.Duration(c.cfg.IdlePollIntervalMS) * time.Millisecond
|
||||
if lastUnixMS > 0 && nowUnixMS-lastUnixMS < minInterval.Milliseconds() {
|
||||
return protocol.Batch{}, false
|
||||
}
|
||||
@@ -304,7 +304,26 @@ func (c *Client) applyResponsePacket(packet protocol.Packet) error {
|
||||
socksConn.LastActivityAt = time.Now()
|
||||
return socksConn.WriteToLocal(packet.Payload)
|
||||
|
||||
case protocol.PacketTypeSOCKSCloseRead, protocol.PacketTypeSOCKSCloseWrite, protocol.PacketTypeSOCKSRST:
|
||||
case protocol.PacketTypeSOCKSCloseRead:
|
||||
_ = socksConn.AckPacket(packet)
|
||||
socksConn.LastActivityAt = time.Now()
|
||||
if err := socksConn.CloseLocalWrite(); err != nil {
|
||||
return err
|
||||
}
|
||||
if socksConn.BothLocalSidesClosed() {
|
||||
return socksConn.CloseLocal()
|
||||
}
|
||||
return nil
|
||||
|
||||
case protocol.PacketTypeSOCKSCloseWrite:
|
||||
_ = socksConn.AckPacket(packet)
|
||||
socksConn.LastActivityAt = time.Now()
|
||||
if socksConn.BothLocalSidesClosed() {
|
||||
return socksConn.CloseLocal()
|
||||
}
|
||||
return nil
|
||||
|
||||
case protocol.PacketTypeSOCKSRST:
|
||||
_ = socksConn.AckPacket(packet)
|
||||
socksConn.LastActivityAt = time.Now()
|
||||
return socksConn.CloseLocal()
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"sync"
|
||||
@@ -38,6 +39,11 @@ type SOCKSConnection struct {
|
||||
|
||||
LocalConn net.Conn
|
||||
localWriteMu sync.Mutex
|
||||
localCloseMu sync.Mutex
|
||||
localReadEOF bool
|
||||
localWriteEOF bool
|
||||
closedC chan struct{}
|
||||
closeOnce sync.Once
|
||||
connectResultC chan error
|
||||
queueMu sync.Mutex
|
||||
OutboundQueue []*SOCKSOutboundQueueItem
|
||||
@@ -74,6 +80,7 @@ func (s *SOCKSConnectionStore) New(clientSessionKey string, clientAddress string
|
||||
CreatedAt: now,
|
||||
LastActivityAt: now,
|
||||
ClientAddress: clientAddress,
|
||||
closedC: make(chan struct{}),
|
||||
connectResultC: make(chan error, 1),
|
||||
InFlight: make(map[string]*SOCKSOutboundQueueItem),
|
||||
}
|
||||
@@ -112,13 +119,65 @@ func (s *SOCKSConnection) WriteToLocal(payload []byte) error {
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) CloseLocal() error {
|
||||
s.localWriteMu.Lock()
|
||||
defer s.localWriteMu.Unlock()
|
||||
var err error
|
||||
s.closeOnce.Do(func() {
|
||||
s.localWriteMu.Lock()
|
||||
defer s.localWriteMu.Unlock()
|
||||
if s.LocalConn != nil {
|
||||
err = s.LocalConn.Close()
|
||||
}
|
||||
close(s.closedC)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
if s.LocalConn == nil {
|
||||
func (s *SOCKSConnection) CloseLocalWrite() error {
|
||||
s.localCloseMu.Lock()
|
||||
defer s.localCloseMu.Unlock()
|
||||
|
||||
if s.localWriteEOF {
|
||||
return nil
|
||||
}
|
||||
return s.LocalConn.Close()
|
||||
s.localWriteEOF = true
|
||||
|
||||
if tcpConn, ok := s.LocalConn.(*net.TCPConn); ok {
|
||||
return tcpConn.CloseWrite()
|
||||
}
|
||||
return s.CloseLocal()
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) CloseLocalRead() error {
|
||||
s.localCloseMu.Lock()
|
||||
defer s.localCloseMu.Unlock()
|
||||
|
||||
if s.localReadEOF {
|
||||
return nil
|
||||
}
|
||||
s.localReadEOF = true
|
||||
|
||||
if tcpConn, ok := s.LocalConn.(*net.TCPConn); ok {
|
||||
return tcpConn.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) MarkLocalReadEOF() {
|
||||
s.localCloseMu.Lock()
|
||||
s.localReadEOF = true
|
||||
s.localCloseMu.Unlock()
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) BothLocalSidesClosed() bool {
|
||||
s.localCloseMu.Lock()
|
||||
defer s.localCloseMu.Unlock()
|
||||
return s.localReadEOF && s.localWriteEOF
|
||||
}
|
||||
|
||||
func (s *SOCKSConnection) WaitUntilClosed(ctx context.Context) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-s.closedC:
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SOCKSConnectionStore) Get(id uint64) *SOCKSConnection {
|
||||
|
||||
@@ -316,7 +316,9 @@ func (c *Client) captureInitialPayload(ctx context.Context, conn net.Conn, socks
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
socksConn.MarkLocalReadEOF()
|
||||
_ = socksConn.EnqueuePacket(socksConn.BuildSOCKSCloseWritePacket())
|
||||
socksConn.WaitUntilClosed(ctx)
|
||||
return nil
|
||||
}
|
||||
if ne, ok := err.(net.Error); ok && ne.Timeout() {
|
||||
|
||||
@@ -32,12 +32,14 @@ type Config struct {
|
||||
WorkerCount int
|
||||
HTTPRequestTimeoutMS int
|
||||
WorkerPollIntervalMS int
|
||||
IdlePollIntervalMS int
|
||||
MaxQueueBytesPerSOCKS int
|
||||
AckTimeoutMS int
|
||||
MaxRetryCount int
|
||||
SessionIdleTimeoutMS int
|
||||
SOCKSIdleTimeoutMS int
|
||||
ReadBodyLimitBytes int
|
||||
MaxServerQueueBytes int
|
||||
}
|
||||
|
||||
func Load(path string) (Config, error) {
|
||||
@@ -53,12 +55,14 @@ func Load(path string) (Config, error) {
|
||||
WorkerCount: 4,
|
||||
HTTPRequestTimeoutMS: 15000,
|
||||
WorkerPollIntervalMS: 200,
|
||||
IdlePollIntervalMS: 1000,
|
||||
MaxQueueBytesPerSOCKS: 1024 * 1024,
|
||||
AckTimeoutMS: 5000,
|
||||
MaxRetryCount: 5,
|
||||
SessionIdleTimeoutMS: 5 * 60 * 1000,
|
||||
SOCKSIdleTimeoutMS: 2 * 60 * 1000,
|
||||
ReadBodyLimitBytes: 2 * 1024 * 1024,
|
||||
MaxServerQueueBytes: 2 * 1024 * 1024,
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
@@ -151,6 +155,12 @@ func Load(path string) (Config, error) {
|
||||
return Config{}, fmt.Errorf("parse WORKER_POLL_INTERVAL_MS: %w", err)
|
||||
}
|
||||
cfg.WorkerPollIntervalMS = interval
|
||||
case "IDLE_POLL_INTERVAL_MS":
|
||||
interval, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("parse IDLE_POLL_INTERVAL_MS: %w", err)
|
||||
}
|
||||
cfg.IdlePollIntervalMS = interval
|
||||
case "MAX_QUEUE_BYTES_PER_SOCKS":
|
||||
size, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
@@ -187,6 +197,12 @@ func Load(path string) (Config, error) {
|
||||
return Config{}, fmt.Errorf("parse READ_BODY_LIMIT_BYTES: %w", err)
|
||||
}
|
||||
cfg.ReadBodyLimitBytes = size
|
||||
case "MAX_SERVER_QUEUE_BYTES":
|
||||
size, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return Config{}, fmt.Errorf("parse MAX_SERVER_QUEUE_BYTES: %w", err)
|
||||
}
|
||||
cfg.MaxServerQueueBytes = size
|
||||
}
|
||||
}
|
||||
|
||||
@@ -216,6 +232,9 @@ func (c Config) ValidateClient() error {
|
||||
if c.WorkerPollIntervalMS < 1 {
|
||||
return fmt.Errorf("invalid WORKER_POLL_INTERVAL_MS: %d", c.WorkerPollIntervalMS)
|
||||
}
|
||||
if c.IdlePollIntervalMS < c.WorkerPollIntervalMS {
|
||||
return fmt.Errorf("IDLE_POLL_INTERVAL_MS must be >= WORKER_POLL_INTERVAL_MS")
|
||||
}
|
||||
if c.AckTimeoutMS < 1 {
|
||||
return fmt.Errorf("invalid ACK_TIMEOUT_MS: %d", c.AckTimeoutMS)
|
||||
}
|
||||
@@ -244,6 +263,9 @@ func (c Config) ValidateServer() error {
|
||||
if c.ReadBodyLimitBytes < c.MaxChunkSize {
|
||||
return fmt.Errorf("READ_BODY_LIMIT_BYTES must be >= MAX_CHUNK_SIZE")
|
||||
}
|
||||
if c.MaxServerQueueBytes < c.MaxChunkSize {
|
||||
return fmt.Errorf("MAX_SERVER_QUEUE_BYTES must be >= MAX_CHUNK_SIZE")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
+97
-21
@@ -13,6 +13,7 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -51,9 +52,13 @@ type SOCKSState struct {
|
||||
OutboundSequence uint64
|
||||
UpstreamConn net.Conn
|
||||
upstreamWriteMu sync.Mutex
|
||||
upstreamCloseMu sync.Mutex
|
||||
upstreamReadEOF bool
|
||||
upstreamWriteEOF bool
|
||||
queueMu sync.Mutex
|
||||
OutboundQueue []protocol.Packet
|
||||
QueuedBytes int
|
||||
MaxQueueBytes int
|
||||
}
|
||||
|
||||
func New(cfg config.Config, lg *logger.Logger) *Server {
|
||||
@@ -215,6 +220,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
Target: packet.Target,
|
||||
ConnectSeen: true,
|
||||
LastSequenceSeen: packet.Sequence,
|
||||
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
|
||||
}
|
||||
session.SOCKSConnections[packet.SOCKSID] = socksState
|
||||
} else {
|
||||
@@ -301,7 +307,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
"<gray>received close_read socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
|
||||
)
|
||||
_ = socksState.closeUpstream()
|
||||
_ = socksState.closeUpstreamRead()
|
||||
return &response, nil
|
||||
|
||||
case protocol.PacketTypeSOCKSCloseWrite:
|
||||
@@ -318,7 +324,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
"<gray>received close_write socks_id=<cyan>%d</cyan> seq=<cyan>%d</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
packet.SOCKSID, packet.Sequence, session.ClientSessionKey,
|
||||
)
|
||||
_ = socksState.closeUpstream()
|
||||
_ = socksState.closeUpstreamWrite()
|
||||
return &response, nil
|
||||
|
||||
case protocol.PacketTypeSOCKSRST:
|
||||
@@ -408,6 +414,7 @@ func (s *Server) getOrCreateSOCKSStateLocked(session *ClientSession, packet prot
|
||||
LastActivityAt: now,
|
||||
Target: packet.Target,
|
||||
LastSequenceSeen: packet.Sequence,
|
||||
MaxQueueBytes: s.cfg.MaxServerQueueBytes,
|
||||
}
|
||||
session.SOCKSConnections[packet.SOCKSID] = socksState
|
||||
s.log.Debugf(
|
||||
@@ -442,7 +449,15 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
|
||||
n, err := socksState.UpstreamConn.Read(buffer)
|
||||
if n > 0 {
|
||||
chunk := append([]byte(nil), buffer[:n]...)
|
||||
socksState.enqueueOutboundData(clientSessionKey, chunk, false)
|
||||
if !socksState.enqueueOutboundData(clientSessionKey, chunk, false) {
|
||||
s.log.Warnf(
|
||||
"<yellow>server outbound queue full socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></yellow>",
|
||||
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
|
||||
)
|
||||
socksState.forceResetPacket(clientSessionKey)
|
||||
_ = socksState.closeUpstream()
|
||||
return
|
||||
}
|
||||
socksState.LastActivityAt = time.Now()
|
||||
queueDepth, queueBytes := socksState.queueSnapshot()
|
||||
s.log.Debugf(
|
||||
@@ -457,14 +472,22 @@ func (s *Server) upstreamReadLoop(clientSessionKey string, socksState *SOCKSStat
|
||||
"<gray>upstream eof socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
|
||||
)
|
||||
socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSCloseRead, true)
|
||||
} else {
|
||||
s.log.Warnf(
|
||||
"<yellow>upstream read failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
|
||||
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, err,
|
||||
)
|
||||
socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSRST, true)
|
||||
_ = socksState.enqueueControlPacket(clientSessionKey, protocol.PacketTypeSOCKSCloseRead, true)
|
||||
_ = socksState.closeUpstreamRead()
|
||||
return
|
||||
}
|
||||
if isClosedConnError(err) {
|
||||
s.log.Debugf(
|
||||
"<gray>upstream closed locally socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey,
|
||||
)
|
||||
return
|
||||
}
|
||||
s.log.Warnf(
|
||||
"<yellow>upstream read failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
|
||||
socksState.ID, targetAddressForLog(socksState.Target), clientSessionKey, err,
|
||||
)
|
||||
socksState.forceResetPacket(clientSessionKey)
|
||||
_ = socksState.closeUpstream()
|
||||
return
|
||||
}
|
||||
@@ -476,28 +499,45 @@ func (s *SOCKSState) nextOutboundSequence() uint64 {
|
||||
return s.OutboundSequence
|
||||
}
|
||||
|
||||
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) {
|
||||
func (s *SOCKSState) enqueueOutboundData(clientSessionKey string, payload []byte, final bool) bool {
|
||||
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
|
||||
packet.SOCKSID = s.ID
|
||||
packet.Sequence = s.nextOutboundSequence()
|
||||
packet.Final = final
|
||||
packet.Payload = payload
|
||||
s.enqueuePacket(packet)
|
||||
return s.enqueuePacket(packet)
|
||||
}
|
||||
|
||||
func (s *SOCKSState) enqueueControlPacket(clientSessionKey string, packetType protocol.PacketType, final bool) {
|
||||
func (s *SOCKSState) enqueueControlPacket(clientSessionKey string, packetType protocol.PacketType, final bool) bool {
|
||||
packet := protocol.NewPacket(clientSessionKey, packetType)
|
||||
packet.SOCKSID = s.ID
|
||||
packet.Sequence = s.nextOutboundSequence()
|
||||
packet.Final = final
|
||||
s.enqueuePacket(packet)
|
||||
return s.enqueuePacket(packet)
|
||||
}
|
||||
|
||||
func (s *SOCKSState) enqueuePacket(packet protocol.Packet) {
|
||||
func (s *SOCKSState) enqueuePacket(packet protocol.Packet) bool {
|
||||
s.queueMu.Lock()
|
||||
defer s.queueMu.Unlock()
|
||||
packetBytes := len(packet.Payload)
|
||||
if s.MaxQueueBytes > 0 && s.QueuedBytes+packetBytes > s.MaxQueueBytes {
|
||||
return false
|
||||
}
|
||||
s.OutboundQueue = append(s.OutboundQueue, packet)
|
||||
s.QueuedBytes += len(packet.Payload)
|
||||
s.QueuedBytes += packetBytes
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *SOCKSState) forceResetPacket(clientSessionKey string) {
|
||||
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSRST)
|
||||
packet.SOCKSID = s.ID
|
||||
packet.Sequence = s.nextOutboundSequence()
|
||||
packet.Final = true
|
||||
|
||||
s.queueMu.Lock()
|
||||
defer s.queueMu.Unlock()
|
||||
s.OutboundQueue = []protocol.Packet{packet}
|
||||
s.QueuedBytes = 0
|
||||
}
|
||||
|
||||
func (s *SOCKSState) queueSnapshot() (items int, bytes int) {
|
||||
@@ -552,18 +592,47 @@ func (s *SOCKSState) writeUpstream(payload []byte) error {
|
||||
}
|
||||
|
||||
func (s *SOCKSState) closeUpstream() error {
|
||||
s.upstreamWriteMu.Lock()
|
||||
defer s.upstreamWriteMu.Unlock()
|
||||
s.upstreamCloseMu.Lock()
|
||||
defer s.upstreamCloseMu.Unlock()
|
||||
if s.UpstreamConn == nil {
|
||||
return nil
|
||||
}
|
||||
target := targetAddressForLog(s.Target)
|
||||
err := s.UpstreamConn.Close()
|
||||
s.UpstreamConn = nil
|
||||
if err == nil {
|
||||
s.upstreamReadEOF = true
|
||||
s.upstreamWriteEOF = true
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SOCKSState) closeUpstreamRead() error {
|
||||
s.upstreamCloseMu.Lock()
|
||||
defer s.upstreamCloseMu.Unlock()
|
||||
if s.upstreamReadEOF {
|
||||
return nil
|
||||
}
|
||||
_ = target
|
||||
s.upstreamReadEOF = true
|
||||
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
|
||||
return tcpConn.CloseRead()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SOCKSState) closeUpstreamWrite() error {
|
||||
s.upstreamCloseMu.Lock()
|
||||
defer s.upstreamCloseMu.Unlock()
|
||||
if s.upstreamWriteEOF {
|
||||
return nil
|
||||
}
|
||||
s.upstreamWriteEOF = true
|
||||
if tcpConn, ok := s.UpstreamConn.(*net.TCPConn); ok {
|
||||
return tcpConn.CloseWrite()
|
||||
}
|
||||
if s.UpstreamConn == nil {
|
||||
return nil
|
||||
}
|
||||
err := s.UpstreamConn.Close()
|
||||
s.UpstreamConn = nil
|
||||
s.upstreamReadEOF = true
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -626,3 +695,10 @@ func targetAddressForLog(target *protocol.Target) string {
|
||||
}
|
||||
return target.Address()
|
||||
}
|
||||
|
||||
func isClosedConnError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(err.Error(), "use of closed network connection")
|
||||
}
|
||||
|
||||
@@ -12,4 +12,5 @@ WORKER_COUNT = 4
|
||||
SESSION_IDLE_TIMEOUT_MS = 300000
|
||||
SOCKS_IDLE_TIMEOUT_MS = 120000
|
||||
READ_BODY_LIMIT_BYTES = 2097152
|
||||
MAX_SERVER_QUEUE_BYTES = 2097152
|
||||
# ==============================================================================
|
||||
|
||||
Reference in New Issue
Block a user