mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-18 06:44:35 +03:00
351 lines
11 KiB
Go
351 lines
11 KiB
Go
package server
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/http/httptest"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"masterhttprelayvpn/internal/config"
|
|
"masterhttprelayvpn/internal/logger"
|
|
"masterhttprelayvpn/internal/protocol"
|
|
)
|
|
|
|
func TestDrainSessionOutboundLockedRespectsGlobalLimits(t *testing.T) {
|
|
srv := &Server{
|
|
cfg: config.Config{
|
|
MaxPacketsPerBatch: 2,
|
|
MaxBatchBytes: 10,
|
|
},
|
|
}
|
|
|
|
session := &ClientSession{
|
|
ClientSessionKey: "client-session",
|
|
SOCKSConnections: map[uint64]*SOCKSState{
|
|
1: {ID: 1},
|
|
2: {ID: 2},
|
|
3: {ID: 3},
|
|
},
|
|
}
|
|
|
|
session.SOCKSConnections[1].OutboundQueue = []protocol.Packet{
|
|
testDataPacket("client-session", 1, 1, "abcd"),
|
|
}
|
|
session.SOCKSConnections[1].QueuedBytes = 4
|
|
|
|
session.SOCKSConnections[2].OutboundQueue = []protocol.Packet{
|
|
testDataPacket("client-session", 2, 1, "efgh"),
|
|
}
|
|
session.SOCKSConnections[2].QueuedBytes = 4
|
|
|
|
session.SOCKSConnections[3].OutboundQueue = []protocol.Packet{
|
|
testDataPacket("client-session", 3, 1, "ijkl"),
|
|
}
|
|
session.SOCKSConnections[3].QueuedBytes = 4
|
|
|
|
drained := srv.drainSessionOutboundLocked(session)
|
|
if len(drained) != 2 {
|
|
t.Fatalf("expected 2 drained packets, got %d", len(drained))
|
|
}
|
|
|
|
totalBytes := 0
|
|
for _, packet := range drained {
|
|
totalBytes += len(packet.Payload)
|
|
}
|
|
if totalBytes > srv.cfg.MaxBatchBytes {
|
|
t.Fatalf("expected drained bytes <= %d, got %d", srv.cfg.MaxBatchBytes, totalBytes)
|
|
}
|
|
|
|
remainingPackets := 0
|
|
for _, socksState := range session.SOCKSConnections {
|
|
remainingPackets += len(socksState.OutboundQueue)
|
|
}
|
|
if remainingPackets != 1 {
|
|
t.Fatalf("expected one packet to remain queued, got %d", remainingPackets)
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateInboundReorderQueuesUntilGapFilled(t *testing.T) {
|
|
socksState := &SOCKSState{
|
|
ConnectAcked: true,
|
|
PendingInbound: make(map[uint64][]protocol.PendingPacket),
|
|
MaxQueueBytes: 1024,
|
|
}
|
|
|
|
packet2 := testDataPacket("client-session", 1, 2, "two")
|
|
ready, duplicate, overflow := socksState.queueInboundPacketLocked(packet2, time.Now(), 8)
|
|
if duplicate || overflow {
|
|
t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow)
|
|
}
|
|
if len(ready) != 0 {
|
|
t.Fatalf("expected no ready packets before sequence gap is filled, got %d", len(ready))
|
|
}
|
|
|
|
packet1 := testDataPacket("client-session", 1, 1, "one")
|
|
ready, duplicate, overflow = socksState.queueInboundPacketLocked(packet1, time.Now(), 8)
|
|
if duplicate || overflow {
|
|
t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow)
|
|
}
|
|
if len(ready) != 2 {
|
|
t.Fatalf("expected 2 ready packets after filling sequence gap, got %d", len(ready))
|
|
}
|
|
if ready[0].Sequence != 1 || ready[1].Sequence != 2 {
|
|
t.Fatalf("expected ordered sequences [1 2], got [%d %d]", ready[0].Sequence, ready[1].Sequence)
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateInboundGapTimeout(t *testing.T) {
|
|
socksState := &SOCKSState{
|
|
PendingInbound: make(map[uint64][]protocol.PendingPacket),
|
|
}
|
|
socksState.PendingInbound[3] = []protocol.PendingPacket{{
|
|
Packet: testDataPacket("client-session", 1, 3, "late"),
|
|
QueuedAt: time.Now().Add(-2 * time.Second),
|
|
}}
|
|
|
|
if !socksState.hasExpiredInboundGapLocked(time.Now(), 500*time.Millisecond) {
|
|
t.Fatal("expected inbound gap timeout to trigger")
|
|
}
|
|
if len(socksState.PendingInbound) != 0 {
|
|
t.Fatalf("expected pending inbound buffer to be cleared, got %d items", len(socksState.PendingInbound))
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateInboundDataWaitsForConnect(t *testing.T) {
|
|
socksState := &SOCKSState{
|
|
PendingInbound: make(map[uint64][]protocol.PendingPacket),
|
|
}
|
|
|
|
packet1 := testDataPacket("client-session", 1, 1, "one")
|
|
ready, duplicate, overflow := socksState.queueInboundPacketLocked(packet1, time.Now(), 8)
|
|
if duplicate || overflow {
|
|
t.Fatalf("unexpected duplicate=%t overflow=%t", duplicate, overflow)
|
|
}
|
|
if len(ready) != 0 {
|
|
t.Fatalf("expected packet to stay buffered before connect, got %d ready packets", len(ready))
|
|
}
|
|
|
|
socksState.ConnectAcked = true
|
|
ready = socksState.drainReadyInboundLocked()
|
|
if len(ready) != 1 {
|
|
t.Fatalf("expected 1 ready packet after connect, got %d", len(ready))
|
|
}
|
|
if ready[0].Sequence != 1 {
|
|
t.Fatalf("expected sequence 1, got %d", ready[0].Sequence)
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateInboundReorderAllowsMultiplePacketTypesPerSequence(t *testing.T) {
|
|
socksState := &SOCKSState{
|
|
ConnectAcked: true,
|
|
PendingInbound: make(map[uint64][]protocol.PendingPacket),
|
|
}
|
|
|
|
closeWrite := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSCloseWrite)
|
|
closeWrite.SOCKSID = 1
|
|
closeWrite.Sequence = 2
|
|
|
|
closeRead := protocol.NewPacket("client-session", protocol.PacketTypeSOCKSCloseRead)
|
|
closeRead.SOCKSID = 1
|
|
closeRead.Sequence = 2
|
|
|
|
ready, duplicate, overflow := socksState.queueInboundPacketLocked(closeWrite, time.Now(), 8)
|
|
if duplicate || overflow || len(ready) != 0 {
|
|
t.Fatalf("expected first close packet to buffer, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
|
}
|
|
|
|
ready, duplicate, overflow = socksState.queueInboundPacketLocked(closeRead, time.Now(), 8)
|
|
if duplicate || overflow || len(ready) != 0 {
|
|
t.Fatalf("expected second close packet on same sequence to buffer, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
|
}
|
|
|
|
data := testDataPacket("client-session", 1, 1, "one")
|
|
ready, duplicate, overflow = socksState.queueInboundPacketLocked(data, time.Now(), 8)
|
|
if duplicate || overflow || len(ready) != 3 {
|
|
t.Fatalf("expected data and both close packets to drain, duplicate=%t overflow=%t ready=%d", duplicate, overflow, len(ready))
|
|
}
|
|
if ready[0].Type != protocol.PacketTypeSOCKSData || ready[1].Type != protocol.PacketTypeSOCKSCloseRead || ready[2].Type != protocol.PacketTypeSOCKSCloseWrite {
|
|
t.Fatalf("unexpected drain order: %s, %s, %s", ready[0].Type, ready[1].Type, ready[2].Type)
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateReleaseClearsQueueState(t *testing.T) {
|
|
socksState := &SOCKSState{
|
|
Target: &protocol.Target{Host: "example.com", Port: 443},
|
|
OutboundQueue: []protocol.Packet{
|
|
testDataPacket("client-session", 1, 1, "hello"),
|
|
testDataPacket("client-session", 1, 2, "world"),
|
|
},
|
|
QueuedBytes: 10,
|
|
}
|
|
|
|
socksState.release()
|
|
|
|
if socksState.Target != nil {
|
|
t.Fatal("expected target to be cleared")
|
|
}
|
|
if len(socksState.OutboundQueue) != 0 {
|
|
t.Fatalf("expected empty outbound queue, got %d items", len(socksState.OutboundQueue))
|
|
}
|
|
if socksState.QueuedBytes != 0 {
|
|
t.Fatalf("expected queued bytes to be reset, got %d", socksState.QueuedBytes)
|
|
}
|
|
}
|
|
|
|
func TestProcessBatchBlockedSessionDoesNotBlockOtherSessions(t *testing.T) {
|
|
dialStarted := make(chan struct{})
|
|
releaseDial := make(chan struct{})
|
|
|
|
srv := New(config.Config{
|
|
MaxPacketsPerBatch: 8,
|
|
MaxBatchBytes: 1024,
|
|
MaxReorderBufferPackets: 8,
|
|
MaxServerQueueBytes: 1024,
|
|
}, logger.New("server-test", "ERROR"))
|
|
srv.dialUpstream = func(network string, address string, timeout time.Duration) (net.Conn, error) {
|
|
if address != "slow.example:80" {
|
|
return nil, errors.New("unexpected dial target")
|
|
}
|
|
close(dialStarted)
|
|
<-releaseDial
|
|
return nil, errors.New("forced dial failure")
|
|
}
|
|
|
|
connect := protocol.NewPacket("session-a", protocol.PacketTypeSOCKSConnect)
|
|
connect.SOCKSID = 1
|
|
connect.Sequence = 0
|
|
connect.Target = &protocol.Target{Host: "slow.example", Port: 80}
|
|
|
|
errCh := make(chan error, 1)
|
|
go func() {
|
|
_, err := srv.processBatch(protocol.NewBatch("session-a", "batch-a", []protocol.Packet{connect}))
|
|
errCh <- err
|
|
}()
|
|
|
|
select {
|
|
case <-dialStarted:
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Fatal("expected slow session dial to start")
|
|
}
|
|
|
|
ping := protocol.NewPacket("session-b", protocol.PacketTypePing)
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := srv.processBatch(protocol.NewBatch("session-b", "batch-b", []protocol.Packet{ping}))
|
|
done <- err
|
|
}()
|
|
|
|
select {
|
|
case err := <-done:
|
|
if err != nil {
|
|
t.Fatalf("expected unrelated session batch to complete, got error: %v", err)
|
|
}
|
|
case <-time.After(300 * time.Millisecond):
|
|
t.Fatal("expected unrelated session batch to complete while first session dial is blocked")
|
|
}
|
|
|
|
close(releaseDial)
|
|
|
|
select {
|
|
case err := <-errCh:
|
|
if err != nil {
|
|
t.Fatalf("expected blocked session batch to convert dial failure into response, got error: %v", err)
|
|
}
|
|
case <-time.After(500 * time.Millisecond):
|
|
t.Fatal("expected blocked session batch to finish after releasing dial")
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateNextOutboundSequenceIsConcurrentSafe(t *testing.T) {
|
|
socksState := &SOCKSState{}
|
|
|
|
const workers = 32
|
|
const iterationsPerWorker = 64
|
|
|
|
start := make(chan struct{})
|
|
var wg sync.WaitGroup
|
|
results := make(chan uint64, workers*iterationsPerWorker)
|
|
|
|
for i := 0; i < workers; i++ {
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
<-start
|
|
for j := 0; j < iterationsPerWorker; j++ {
|
|
results <- socksState.nextOutboundSequence()
|
|
}
|
|
}()
|
|
}
|
|
|
|
close(start)
|
|
wg.Wait()
|
|
close(results)
|
|
|
|
seen := make(map[uint64]struct{}, workers*iterationsPerWorker)
|
|
var maxSeq uint64
|
|
for seq := range results {
|
|
if _, exists := seen[seq]; exists {
|
|
t.Fatalf("duplicate outbound sequence generated: %d", seq)
|
|
}
|
|
seen[seq] = struct{}{}
|
|
if seq > maxSeq {
|
|
maxSeq = seq
|
|
}
|
|
}
|
|
|
|
expected := workers * iterationsPerWorker
|
|
if len(seen) != expected {
|
|
t.Fatalf("expected %d unique sequences, got %d", expected, len(seen))
|
|
}
|
|
if maxSeq != uint64(expected) {
|
|
t.Fatalf("expected max sequence %d, got %d", expected, maxSeq)
|
|
}
|
|
}
|
|
|
|
func TestSOCKSStateCloseUpstreamClearsConnectionSnapshot(t *testing.T) {
|
|
serverConn, clientConn := net.Pipe()
|
|
defer clientConn.Close()
|
|
|
|
socksState := &SOCKSState{}
|
|
socksState.setUpstreamConn(serverConn)
|
|
|
|
if err := socksState.closeUpstream(); err != nil {
|
|
t.Fatalf("expected closeUpstream to succeed, got %v", err)
|
|
}
|
|
|
|
if _, ok := socksState.currentUpstreamConn(); ok {
|
|
t.Fatal("expected upstream connection snapshot to be cleared after close")
|
|
}
|
|
}
|
|
|
|
func TestHandleRelayAppliesAntiBufferingHeaders(t *testing.T) {
|
|
srv := New(config.Config{}, nil)
|
|
request := httptest.NewRequest("GET", "/relay", nil)
|
|
recorder := httptest.NewRecorder()
|
|
|
|
srv.handleRelay(recorder, request)
|
|
|
|
if got := recorder.Header().Get("X-Accel-Buffering"); got != "no" {
|
|
t.Fatalf("expected X-Accel-Buffering=no, got %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Cache-Control"); got != "no-store, no-cache, must-revalidate" {
|
|
t.Fatalf("unexpected Cache-Control header: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Pragma"); got != "no-cache" {
|
|
t.Fatalf("unexpected Pragma header: %q", got)
|
|
}
|
|
if got := recorder.Header().Get("Expires"); got != "0" {
|
|
t.Fatalf("unexpected Expires header: %q", got)
|
|
}
|
|
}
|
|
|
|
func testDataPacket(clientSessionKey string, socksID uint64, sequence uint64, payload string) protocol.Packet {
|
|
packet := protocol.NewPacket(clientSessionKey, protocol.PacketTypeSOCKSData)
|
|
packet.SOCKSID = socksID
|
|
packet.Sequence = sequence
|
|
packet.Payload = []byte(payload)
|
|
return packet
|
|
}
|