mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-18 23:54:37 +03:00
Refactor server locking to process batches per session instead of under a global mutex
This commit is contained in:
+64
-18
@@ -26,11 +26,13 @@ type Server struct {
|
||||
cfg config.Config
|
||||
log *logger.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*ClientSession
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*ClientSession
|
||||
dialUpstream func(network string, address string, timeout time.Duration) (net.Conn, error)
|
||||
}
|
||||
|
||||
type ClientSession struct {
|
||||
mu sync.Mutex
|
||||
ClientSessionKey string
|
||||
CreatedAt time.Time
|
||||
LastActivityAt time.Time
|
||||
@@ -71,9 +73,10 @@ type PendingInboundPacket struct {
|
||||
|
||||
func New(cfg config.Config, lg *logger.Logger) *Server {
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
log: lg,
|
||||
sessions: make(map[string]*ClientSession),
|
||||
cfg: cfg,
|
||||
log: lg,
|
||||
sessions: make(map[string]*ClientSession),
|
||||
dialUpstream: net.DialTimeout,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -171,7 +174,9 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
|
||||
session := s.getOrCreateSession(batch.ClientSessionKey)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
session.mu.Lock()
|
||||
defer session.mu.Unlock()
|
||||
|
||||
session.LastActivityAt = now
|
||||
|
||||
responses := make([]protocol.Packet, 0, len(batch.Packets))
|
||||
@@ -183,7 +188,6 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
|
||||
)
|
||||
orderedResponses, err := s.processPacketLocked(session, packet, now)
|
||||
if err != nil {
|
||||
s.mu.Unlock()
|
||||
return protocol.Batch{}, err
|
||||
}
|
||||
for _, response := range orderedResponses {
|
||||
@@ -201,7 +205,6 @@ func (s *Server) processBatch(batch protocol.Batch) (protocol.Batch, error) {
|
||||
)
|
||||
responses = append(responses, outbound)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
if len(responses) == 0 {
|
||||
return protocol.Batch{}, nil
|
||||
@@ -250,7 +253,8 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
"<gray>dial upstream socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan></gray>",
|
||||
packet.SOCKSID, packet.Target.Address(), session.ClientSessionKey,
|
||||
)
|
||||
upstreamConn, err := net.DialTimeout("tcp", packet.Target.Address(), 10*time.Second)
|
||||
upstreamConn, err := s.dial("tcp", packet.Target.Address(), 10*time.Second)
|
||||
|
||||
if err != nil {
|
||||
s.log.Warnf(
|
||||
"<yellow>upstream dial failed socks_id=<cyan>%d</cyan> target=<cyan>%s</cyan> client_session_key=<cyan>%s</cyan> error=<cyan>%v</cyan></yellow>",
|
||||
@@ -262,6 +266,7 @@ func (s *Server) processPacketLocked(session *ClientSession, packet protocol.Pac
|
||||
response.Payload = []byte(err.Error())
|
||||
return []protocol.Packet{response}, nil
|
||||
}
|
||||
|
||||
socksState.UpstreamConn = upstreamConn
|
||||
socksState.ConnectAcked = true
|
||||
s.log.Infof(
|
||||
@@ -887,10 +892,16 @@ func (s *Server) cleanupExpired() {
|
||||
socksTTL := time.Duration(s.cfg.SOCKSIdleTimeoutMS) * time.Millisecond
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.mu.RLock()
|
||||
sessionSnapshots := make(map[string]*ClientSession, len(s.sessions))
|
||||
for clientSessionKey, session := range s.sessions {
|
||||
sessionSnapshots[clientSessionKey] = session
|
||||
}
|
||||
s.mu.RUnlock()
|
||||
|
||||
sessionsToDelete := make([]string, 0)
|
||||
for clientSessionKey, session := range sessionSnapshots {
|
||||
session.mu.Lock()
|
||||
for socksID, socksState := range session.SOCKSConnections {
|
||||
if now.Sub(socksState.LastActivityAt) > socksTTL {
|
||||
targetAddress := targetAddressForLog(socksState.Target)
|
||||
@@ -902,23 +913,58 @@ func (s *Server) cleanupExpired() {
|
||||
}
|
||||
|
||||
if len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL {
|
||||
delete(s.sessions, clientSessionKey)
|
||||
s.log.Infof("<yellow>expired client session <cyan>%s</cyan></yellow>", clientSessionKey)
|
||||
sessionsToDelete = append(sessionsToDelete, clientSessionKey)
|
||||
}
|
||||
session.mu.Unlock()
|
||||
}
|
||||
|
||||
if len(sessionsToDelete) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for _, clientSessionKey := range sessionsToDelete {
|
||||
session := s.sessions[clientSessionKey]
|
||||
if session == nil {
|
||||
continue
|
||||
}
|
||||
session.mu.Lock()
|
||||
expired := len(session.SOCKSConnections) == 0 && now.Sub(session.LastActivityAt) > sessionTTL
|
||||
session.mu.Unlock()
|
||||
if !expired {
|
||||
continue
|
||||
}
|
||||
delete(s.sessions, clientSessionKey)
|
||||
s.log.Infof("<yellow>expired client session <cyan>%s</cyan></yellow>", clientSessionKey)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) SessionSnapshot() (sessions int, socksConnections int) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
sessions = len(s.sessions)
|
||||
sessionList := make([]*ClientSession, 0, len(s.sessions))
|
||||
for _, session := range s.sessions {
|
||||
socksConnections += len(session.SOCKSConnections)
|
||||
sessionList = append(sessionList, session)
|
||||
}
|
||||
sessions = len(sessionList)
|
||||
s.mu.RUnlock()
|
||||
|
||||
for _, session := range sessionList {
|
||||
session.mu.Lock()
|
||||
socksConnections += len(session.SOCKSConnections)
|
||||
session.mu.Unlock()
|
||||
}
|
||||
|
||||
return sessions, socksConnections
|
||||
}
|
||||
|
||||
func (s *Server) dial(network string, address string, timeout time.Duration) (net.Conn, error) {
|
||||
if s.dialUpstream == nil {
|
||||
return net.DialTimeout(network, address, timeout)
|
||||
}
|
||||
return s.dialUpstream(network, address, timeout)
|
||||
}
|
||||
|
||||
func LocalListenAddress(host string, port int) string {
|
||||
return net.JoinHostPort(host, fmt.Sprintf("%d", port))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user