mirror of
https://github.com/therealaleph/MasterHttpRelayVPN-RUST.git
synced 2026-05-17 21:24:48 +03:00
perf(tunnel): zero-copy mux + base64 off mux thread (#881)
Performance refactor of full-tunnel mode hot path. No wire-protocol or
behavior changes — internal data flow only.
**1. Zero-copy reads via `Bytes`/`BytesMut`**
`tunnel_loop` and the SOCKS5 UDP receive loop drop their per-iteration
`Vec::to_vec()` copies. `MuxMsg::{ConnectData,Data,UdpOpen,UdpData}` now
carry `Bytes` instead of `Vec<u8>`/`Arc<Vec<u8>>`; the `Arc::try_unwrap`
dance for `pending_client_data` is gone (Bytes is already Arc-backed).
TCP path is threshold-based to avoid the obvious memory regression:
- n ≥ 32 KB: `BytesMut::split().freeze()` — saves the 64 KB memcpy on
hot downloads.
- n < 32 KB: `Bytes::copy_from_slice` + `buf.clear()` — payload-sized
retention. Without this split, a queued tiny TLS record would refcount-
pin the full 64 KB recv buffer (worst case ~96 MB on a backpressured
tunnel).
UDP path: fixed `Vec<u8>` recv buffer + `Bytes::copy_from_slice` after
the 9 KB MAX_UDP_PAYLOAD_BYTES guard. `parse_socks5_udp_packet` split
into `_offsets` + `&[u8]` wrapper so callers stay on the reusable buffer.
**2. Base64 encoding moved off the single mux thread**
New internal `PendingOp { data: Option<Bytes>, encode_empty: bool }`
flows through `mux_loop` with raw bytes. Actual `B64.encode(...)` runs
in `fire_batch`'s spawned task, after the per-deployment semaphore. Up
to ~3 MB of encoding per batch (50 ops × 64 KB) no longer serializes
the single mux task.
**3. Code quality**
- `BatchAccum::push_or_fire` collapses 4× ~25-line match arms → ~10 each.
- `should_fire(pending_len, payload_bytes, op_bytes)` extracted with
`saturating_add` for a self-contained contract.
- `encode_pending(p) -> BatchOp` extracted as a free function so the
encoding contract is directly testable.
**Tests:** 208/208 (was 200, +8 new):
- `encode_pending_*` × 4 — base64-encode contract per MuxMsg variant
- `should_fire_*` × 3 — first-op, MAX_BATCH_OPS boundary, payload cap
- `batch_accum_reindexes_after_flush` — regression test for post-flush
reply index lookup in `fire_batch`
**Public API:** `TunnelMux::udp_open` and `udp_data` now take
`data: impl Into<Bytes>` instead of `Vec<u8>`. Existing call sites
keep compiling.
Reviewed via Anthropic Claude.
Co-Authored-By: dazzling-no-more <noreply@github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
+48
-12
@@ -3,6 +3,7 @@ use std::net::SocketAddr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use bytes::Bytes;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::{TcpListener, TcpStream, UdpSocket};
|
||||
use tokio::sync::{mpsc, Mutex};
|
||||
@@ -965,7 +966,7 @@ struct SocksUdpTarget {
|
||||
/// to abort mid-await.
|
||||
struct UdpRelaySession {
|
||||
sid: String,
|
||||
uplink: mpsc::Sender<Vec<u8>>,
|
||||
uplink: mpsc::Sender<Bytes>,
|
||||
}
|
||||
|
||||
/// All per-ASSOCIATE UDP relay state behind a single mutex so insertion
|
||||
@@ -991,7 +992,7 @@ impl UdpRelayState {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_uplink(&self, target: &SocksUdpTarget) -> Option<mpsc::Sender<Vec<u8>>> {
|
||||
fn get_uplink(&self, target: &SocksUdpTarget) -> Option<mpsc::Sender<Bytes>> {
|
||||
self.sessions.get(target).map(|s| s.uplink.clone())
|
||||
}
|
||||
|
||||
@@ -1118,7 +1119,15 @@ async fn handle_socks5_udp_associate(
|
||||
client_peer_ip
|
||||
);
|
||||
|
||||
let mut buf = vec![0u8; SOCKS5_UDP_RECV_BUF_BYTES];
|
||||
// Fixed reusable recv buffer. We deliberately don't go the
|
||||
// `BytesMut::split().freeze()` route here even though `tunnel_loop`
|
||||
// does: in TCP the read region IS the payload, but UDP always
|
||||
// slices the SOCKS5 header off, so we'd be copying out anyway —
|
||||
// and a frozen `Bytes` from the recv buf would refcount-pin the
|
||||
// full ~65 KB allocation behind a tiny DNS reply, ballooning
|
||||
// memory under bursts. Right-sized `Bytes::copy_from_slice` on
|
||||
// accepted payloads keeps retention proportional to actual data.
|
||||
let mut recv_buf = vec![0u8; SOCKS5_UDP_RECV_BUF_BYTES];
|
||||
let mut control_buf = [0u8; 1];
|
||||
let mut client_addr: Option<SocketAddr> = None;
|
||||
let state: Arc<Mutex<UdpRelayState>> = Arc::new(Mutex::new(UdpRelayState::new()));
|
||||
@@ -1134,7 +1143,7 @@ async fn handle_socks5_udp_associate(
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
recv = udp.recv_from(&mut buf) => {
|
||||
recv = udp.recv_from(&mut recv_buf) => {
|
||||
let (n, peer) = match recv {
|
||||
Ok(v) => v,
|
||||
Err(e) => {
|
||||
@@ -1142,6 +1151,7 @@ async fn handle_socks5_udp_associate(
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Source-IP check: anything not from the SOCKS5 client's
|
||||
// host is dropped silently.
|
||||
if peer.ip() != client_peer_ip {
|
||||
@@ -1162,9 +1172,10 @@ async fn handle_socks5_udp_associate(
|
||||
// can race one bad packet to DoS the legitimate client
|
||||
// (whose real datagram, sent from a different ephemeral
|
||||
// port, would then be silently rejected).
|
||||
let Some((target, payload)) = parse_socks5_udp_packet(&buf[..n]) else {
|
||||
let Some((target, payload_off)) = parse_socks5_udp_packet_offsets(&recv_buf[..n]) else {
|
||||
continue;
|
||||
};
|
||||
let payload_slice = &recv_buf[payload_off..n];
|
||||
|
||||
// Issue #213: client-side QUIC block. UDP/443 is
|
||||
// HTTP/3 — drop the datagram silently so the client
|
||||
@@ -1206,19 +1217,26 @@ async fn handle_socks5_udp_associate(
|
||||
// the mux. Each datagram costs ~payload * 1.33 in the
|
||||
// batched JSON envelope plus tunnel-node CPU; uncapped,
|
||||
// a runaway client can exhaust Apps Script quota.
|
||||
if payload.len() > MAX_UDP_PAYLOAD_BYTES {
|
||||
if payload_slice.len() > MAX_UDP_PAYLOAD_BYTES {
|
||||
oversized_dropped += 1;
|
||||
if oversized_dropped == 1 || oversized_dropped.is_multiple_of(100) {
|
||||
tracing::debug!(
|
||||
"udp datagram dropped: {} B > {} B (count={})",
|
||||
payload.len(),
|
||||
payload_slice.len(),
|
||||
MAX_UDP_PAYLOAD_BYTES,
|
||||
oversized_dropped,
|
||||
);
|
||||
}
|
||||
continue;
|
||||
}
|
||||
let payload = payload.to_vec();
|
||||
|
||||
// Right-sized copy: the queued/in-flight payload owns its
|
||||
// own allocation, so the recv buffer can be reused on the
|
||||
// next iteration without keeping every queued datagram
|
||||
// alive. Sized to the actual payload (≤ MAX_UDP_PAYLOAD_BYTES
|
||||
// = 9 KB after the guard above), not the full ~65 KB recv
|
||||
// buffer.
|
||||
let payload = Bytes::copy_from_slice(payload_slice);
|
||||
|
||||
// Fast path: existing session — push payload onto its
|
||||
// bounded uplink queue, drop on overflow (UDP semantics).
|
||||
@@ -1292,7 +1310,7 @@ async fn handle_socks5_udp_associate(
|
||||
continue;
|
||||
}
|
||||
|
||||
let (uplink_tx, uplink_rx) = mpsc::channel::<Vec<u8>>(UDP_UPLINK_QUEUE);
|
||||
let (uplink_tx, uplink_rx) = mpsc::channel::<Bytes>(UDP_UPLINK_QUEUE);
|
||||
let task_mux = mux.clone();
|
||||
let task_udp = udp.clone();
|
||||
let task_target = target.clone();
|
||||
@@ -1365,7 +1383,7 @@ async fn udp_session_task(
|
||||
sid: String,
|
||||
target: SocksUdpTarget,
|
||||
client_addr: SocketAddr,
|
||||
mut uplink_rx: mpsc::Receiver<Vec<u8>>,
|
||||
mut uplink_rx: mpsc::Receiver<Bytes>,
|
||||
) {
|
||||
let mut backoff = UDP_INITIAL_POLL_DELAY;
|
||||
loop {
|
||||
@@ -1473,7 +1491,20 @@ async fn write_socks5_reply(
|
||||
sock.flush().await
|
||||
}
|
||||
|
||||
fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
|
||||
/// Parse the SOCKS5 UDP frame header and return the target plus the byte
|
||||
/// offset at which the payload starts. Splitting "structure parsing"
|
||||
/// from "give me a payload slice" lets the recv hot path stay on a
|
||||
/// fixed reusable `Vec<u8>` buffer and only allocate a right-sized
|
||||
/// `Bytes::copy_from_slice(&recv_buf[off..n])` for accepted payloads
|
||||
/// (after the size guard). DO NOT change this back to a zero-copy
|
||||
/// `Bytes::slice` path: that was tried and reverted because slicing
|
||||
/// the recv buffer with `bytes` 1.x refcounts the whole ~65 KB
|
||||
/// allocation, so a queued tiny DNS reply pinned the full datagram-
|
||||
/// sized buffer until it drained — burst retention regressed by
|
||||
/// orders of magnitude on UDP-heavy workloads. The thin
|
||||
/// `parse_socks5_udp_packet` wrapper below keeps existing `&[u8]`
|
||||
/// callers (tests) working.
|
||||
fn parse_socks5_udp_packet_offsets(buf: &[u8]) -> Option<(SocksUdpTarget, usize)> {
|
||||
if buf.len() < 4 || buf[0] != 0 || buf[1] != 0 || buf[2] != 0 {
|
||||
return None;
|
||||
}
|
||||
@@ -1528,10 +1559,15 @@ fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
|
||||
atyp,
|
||||
addr,
|
||||
},
|
||||
&buf[pos..],
|
||||
pos,
|
||||
))
|
||||
}
|
||||
|
||||
fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
|
||||
let (target, off) = parse_socks5_udp_packet_offsets(buf)?;
|
||||
Some((target, &buf[off..]))
|
||||
}
|
||||
|
||||
fn build_socks5_udp_packet(target: &SocksUdpTarget, payload: &[u8]) -> Vec<u8> {
|
||||
let mut out = Vec::with_capacity(4 + target.addr.len() + 2 + payload.len() + 1);
|
||||
out.extend_from_slice(&[0, 0, 0, target.atyp]);
|
||||
|
||||
+421
-150
@@ -19,6 +19,7 @@ use std::time::{Duration, Instant};
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as B64;
|
||||
use base64::Engine;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot, Semaphore};
|
||||
@@ -163,25 +164,26 @@ enum MuxMsg {
|
||||
ConnectData {
|
||||
host: String,
|
||||
port: u16,
|
||||
// Arc so the caller can hand the buffer to the mux AND keep a ref
|
||||
// for the fallback path without an extra 64 KB copy per session.
|
||||
data: Arc<Vec<u8>>,
|
||||
// `Bytes` is internally Arc-backed, so the caller can cheaply
|
||||
// clone() to keep its own reference for the unsupported-fallback
|
||||
// replay path without an extra 64 KB copy per session.
|
||||
data: Bytes,
|
||||
reply: BatchedReply,
|
||||
},
|
||||
Data {
|
||||
sid: String,
|
||||
data: Vec<u8>,
|
||||
data: Bytes,
|
||||
reply: BatchedReply,
|
||||
},
|
||||
UdpOpen {
|
||||
host: String,
|
||||
port: u16,
|
||||
data: Vec<u8>,
|
||||
data: Bytes,
|
||||
reply: BatchedReply,
|
||||
},
|
||||
UdpData {
|
||||
sid: String,
|
||||
data: Vec<u8>,
|
||||
data: Bytes,
|
||||
reply: BatchedReply,
|
||||
},
|
||||
Close {
|
||||
@@ -189,6 +191,25 @@ enum MuxMsg {
|
||||
},
|
||||
}
|
||||
|
||||
/// Raw, not-yet-encoded form of a batch operation. Lives only inside
|
||||
/// `mux_loop` and gets converted to `BatchOp` (with base64-encoded `d`)
|
||||
/// inside `fire_batch`'s spawned task — keeping the encoding work off
|
||||
/// the single mux thread, which previously had to base64 every op
|
||||
/// inline before it could move on to the next message.
|
||||
struct PendingOp {
|
||||
op: &'static str,
|
||||
sid: Option<String>,
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
/// Raw payload. `None` for empty polls / opless ops; `Some` even
|
||||
/// when empty preserves the connect_data shape (always emits `d`).
|
||||
data: Option<Bytes>,
|
||||
/// True for ops that must serialize `d` even when empty (currently
|
||||
/// only `connect_data`, which uses presence of `d` as the signal
|
||||
/// that the caller is opting into the bundled-first-bytes flow).
|
||||
encode_empty: bool,
|
||||
}
|
||||
|
||||
pub struct TunnelMux {
|
||||
tx: mpsc::Sender<MuxMsg>,
|
||||
/// Set to `true` after the first time the tunnel-node rejects
|
||||
@@ -316,13 +337,13 @@ impl TunnelMux {
|
||||
&self,
|
||||
host: &str,
|
||||
port: u16,
|
||||
data: Vec<u8>,
|
||||
data: impl Into<Bytes>,
|
||||
) -> Result<TunnelResponse, String> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.send(MuxMsg::UdpOpen {
|
||||
host: host.to_string(),
|
||||
port,
|
||||
data,
|
||||
data: data.into(),
|
||||
reply: reply_tx,
|
||||
})
|
||||
.await;
|
||||
@@ -333,11 +354,15 @@ impl TunnelMux {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn udp_data(&self, sid: &str, data: Vec<u8>) -> Result<TunnelResponse, String> {
|
||||
pub async fn udp_data(
|
||||
&self,
|
||||
sid: &str,
|
||||
data: impl Into<Bytes>,
|
||||
) -> Result<TunnelResponse, String> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
self.send(MuxMsg::UdpData {
|
||||
sid: sid.to_string(),
|
||||
data,
|
||||
data: data.into(),
|
||||
reply: reply_tx,
|
||||
})
|
||||
.await;
|
||||
@@ -619,10 +644,8 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
|
||||
}
|
||||
|
||||
// Split: plain connects go parallel, data-bearing ops get batched.
|
||||
let mut data_ops: Vec<BatchOp> = Vec::new();
|
||||
let mut data_replies: Vec<(usize, BatchedReply)> = Vec::new();
|
||||
let mut accum = BatchAccum::new();
|
||||
let mut close_sids: Vec<String> = Vec::new();
|
||||
let mut batch_payload_bytes: usize = 0;
|
||||
|
||||
for msg in msgs {
|
||||
match msg {
|
||||
@@ -648,68 +671,28 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
|
||||
data,
|
||||
reply,
|
||||
} => {
|
||||
let encoded = Some(B64.encode(data.as_slice()));
|
||||
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
|
||||
|
||||
if !data_ops.is_empty()
|
||||
&& (data_ops.len() >= MAX_BATCH_OPS
|
||||
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
|
||||
{
|
||||
fire_batch(
|
||||
&sems,
|
||||
&fronter,
|
||||
std::mem::take(&mut data_ops),
|
||||
std::mem::take(&mut data_replies),
|
||||
)
|
||||
.await;
|
||||
batch_payload_bytes = 0;
|
||||
}
|
||||
|
||||
let idx = data_ops.len();
|
||||
data_ops.push(BatchOp {
|
||||
op: "connect_data".into(),
|
||||
let op_bytes = encoded_len(data.len());
|
||||
let op = PendingOp {
|
||||
op: "connect_data",
|
||||
sid: None,
|
||||
host: Some(host),
|
||||
port: Some(port),
|
||||
d: encoded,
|
||||
});
|
||||
data_replies.push((idx, reply));
|
||||
batch_payload_bytes += op_bytes;
|
||||
data: Some(data),
|
||||
encode_empty: true,
|
||||
};
|
||||
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
|
||||
}
|
||||
MuxMsg::Data { sid, data, reply } => {
|
||||
let encoded = if data.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(B64.encode(&data))
|
||||
};
|
||||
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
|
||||
|
||||
// If adding this op would exceed limits, fire current
|
||||
// batch first and start a new one.
|
||||
if !data_ops.is_empty()
|
||||
&& (data_ops.len() >= MAX_BATCH_OPS
|
||||
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
|
||||
{
|
||||
fire_batch(
|
||||
&sems,
|
||||
&fronter,
|
||||
std::mem::take(&mut data_ops),
|
||||
std::mem::take(&mut data_replies),
|
||||
)
|
||||
.await;
|
||||
batch_payload_bytes = 0;
|
||||
}
|
||||
|
||||
let idx = data_ops.len();
|
||||
data_ops.push(BatchOp {
|
||||
op: "data".into(),
|
||||
let op_bytes = encoded_len(data.len());
|
||||
let op = PendingOp {
|
||||
op: "data",
|
||||
sid: Some(sid),
|
||||
host: None,
|
||||
port: None,
|
||||
d: encoded,
|
||||
});
|
||||
data_replies.push((idx, reply));
|
||||
batch_payload_bytes += op_bytes;
|
||||
data: if data.is_empty() { None } else { Some(data) },
|
||||
encode_empty: false,
|
||||
};
|
||||
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
|
||||
}
|
||||
MuxMsg::UdpOpen {
|
||||
host,
|
||||
@@ -717,70 +700,28 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
|
||||
data,
|
||||
reply,
|
||||
} => {
|
||||
let encoded = if data.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(B64.encode(&data))
|
||||
};
|
||||
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
|
||||
|
||||
if !data_ops.is_empty()
|
||||
&& (data_ops.len() >= MAX_BATCH_OPS
|
||||
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
|
||||
{
|
||||
fire_batch(
|
||||
&sems,
|
||||
&fronter,
|
||||
std::mem::take(&mut data_ops),
|
||||
std::mem::take(&mut data_replies),
|
||||
)
|
||||
.await;
|
||||
batch_payload_bytes = 0;
|
||||
}
|
||||
|
||||
let idx = data_ops.len();
|
||||
data_ops.push(BatchOp {
|
||||
op: "udp_open".into(),
|
||||
let op_bytes = encoded_len(data.len());
|
||||
let op = PendingOp {
|
||||
op: "udp_open",
|
||||
sid: None,
|
||||
host: Some(host),
|
||||
port: Some(port),
|
||||
d: encoded,
|
||||
});
|
||||
data_replies.push((idx, reply));
|
||||
batch_payload_bytes += op_bytes;
|
||||
data: if data.is_empty() { None } else { Some(data) },
|
||||
encode_empty: false,
|
||||
};
|
||||
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
|
||||
}
|
||||
MuxMsg::UdpData { sid, data, reply } => {
|
||||
let encoded = if data.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(B64.encode(&data))
|
||||
};
|
||||
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
|
||||
|
||||
if !data_ops.is_empty()
|
||||
&& (data_ops.len() >= MAX_BATCH_OPS
|
||||
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
|
||||
{
|
||||
fire_batch(
|
||||
&sems,
|
||||
&fronter,
|
||||
std::mem::take(&mut data_ops),
|
||||
std::mem::take(&mut data_replies),
|
||||
)
|
||||
.await;
|
||||
batch_payload_bytes = 0;
|
||||
}
|
||||
|
||||
let idx = data_ops.len();
|
||||
data_ops.push(BatchOp {
|
||||
op: "udp_data".into(),
|
||||
let op_bytes = encoded_len(data.len());
|
||||
let op = PendingOp {
|
||||
op: "udp_data",
|
||||
sid: Some(sid),
|
||||
host: None,
|
||||
port: None,
|
||||
d: encoded,
|
||||
});
|
||||
data_replies.push((idx, reply));
|
||||
batch_payload_bytes += op_bytes;
|
||||
data: if data.is_empty() { None } else { Some(data) },
|
||||
encode_empty: false,
|
||||
};
|
||||
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
|
||||
}
|
||||
MuxMsg::Close { sid } => {
|
||||
close_sids.push(sid);
|
||||
@@ -788,21 +729,120 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
|
||||
}
|
||||
}
|
||||
|
||||
// `close` ops piggyback on whatever batch we're about to fire — no
|
||||
// reply channel, no payload, just tell tunnel-node to drop the sid.
|
||||
for sid in close_sids {
|
||||
data_ops.push(BatchOp {
|
||||
op: "close".into(),
|
||||
accum.pending_ops.push(PendingOp {
|
||||
op: "close",
|
||||
sid: Some(sid),
|
||||
host: None,
|
||||
port: None,
|
||||
d: None,
|
||||
data: None,
|
||||
encode_empty: false,
|
||||
});
|
||||
}
|
||||
|
||||
if data_ops.is_empty() {
|
||||
if accum.pending_ops.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
fire_batch(&sems, &fronter, data_ops, data_replies).await;
|
||||
fire_batch(&sems, &fronter, accum.pending_ops, accum.data_replies).await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-iteration accumulator for `mux_loop`. Owns the three fields that
|
||||
/// the data-bearing arms used to mutate in lockstep, with a single
|
||||
/// `push_or_fire` entry point so the cap-then-push pattern lives in one
|
||||
/// place instead of being copy-pasted into every arm.
|
||||
struct BatchAccum {
|
||||
pending_ops: Vec<PendingOp>,
|
||||
data_replies: Vec<(usize, BatchedReply)>,
|
||||
payload_bytes: usize,
|
||||
}
|
||||
|
||||
impl BatchAccum {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
pending_ops: Vec::new(),
|
||||
data_replies: Vec::new(),
|
||||
payload_bytes: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Append `op` (with its `reply` channel and pre-computed `op_bytes`),
|
||||
/// firing the current accumulator first if `op` would push us past
|
||||
/// `MAX_BATCH_OPS` or `MAX_BATCH_PAYLOAD_BYTES`. After a fire the
|
||||
/// accumulator is fresh for the new op.
|
||||
async fn push_or_fire(
|
||||
&mut self,
|
||||
op: PendingOp,
|
||||
op_bytes: usize,
|
||||
reply: BatchedReply,
|
||||
sems: &Arc<HashMap<String, Arc<Semaphore>>>,
|
||||
fronter: &Arc<DomainFronter>,
|
||||
) {
|
||||
if should_fire(self.pending_ops.len(), self.payload_bytes, op_bytes) {
|
||||
fire_batch(
|
||||
sems,
|
||||
fronter,
|
||||
std::mem::take(&mut self.pending_ops),
|
||||
std::mem::take(&mut self.data_replies),
|
||||
)
|
||||
.await;
|
||||
self.payload_bytes = 0;
|
||||
}
|
||||
let idx = self.pending_ops.len();
|
||||
self.pending_ops.push(op);
|
||||
self.data_replies.push((idx, reply));
|
||||
self.payload_bytes += op_bytes;
|
||||
}
|
||||
}
|
||||
|
||||
/// Threshold predicate for `BatchAccum::push_or_fire`: would adding an
|
||||
/// op of `op_bytes` to a batch already holding `pending_len` ops and
|
||||
/// `payload_bytes` of base64 cross either the per-batch op cap or
|
||||
/// the payload-size cap?
|
||||
///
|
||||
/// Extracted from the inline `if` so the tunable boundary — including
|
||||
/// the "first op never fires" rule (`pending_len == 0`) — has direct
|
||||
/// unit-test coverage without spinning up a real `fire_batch`.
|
||||
///
|
||||
/// `saturating_add` keeps the helper's contract self-contained: a
|
||||
/// pathological `op_bytes` near `usize::MAX` clamps to "yes, fire"
|
||||
/// rather than wrapping around and silently letting an oversized op
|
||||
/// slip past the cap. Today's callers only feed `encoded_len(n)` on
|
||||
/// reasonable buffer sizes, but the predicate is the wrong place to
|
||||
/// rely on caller bounds.
|
||||
fn should_fire(pending_len: usize, payload_bytes: usize, op_bytes: usize) -> bool {
|
||||
pending_len > 0
|
||||
&& (pending_len >= MAX_BATCH_OPS
|
||||
|| payload_bytes.saturating_add(op_bytes) > MAX_BATCH_PAYLOAD_BYTES)
|
||||
}
|
||||
|
||||
/// Exact base64-encoded length of `n` raw bytes (standard padding):
|
||||
/// `((n + 2) / 3) * 4`. Used by `mux_loop` to enforce
|
||||
/// `MAX_BATCH_PAYLOAD_BYTES` without doing the actual encoding inline —
|
||||
/// that work now happens in `fire_batch`'s spawned task.
|
||||
fn encoded_len(n: usize) -> usize {
|
||||
n.div_ceil(3) * 4
|
||||
}
|
||||
|
||||
/// Build the wire-shape `BatchOp` from an internal `PendingOp`. Free
|
||||
/// function so the encoding contract — non-empty data → encoded,
|
||||
/// empty connect_data → `Some("")`, anything else empty → `None` — is
|
||||
/// directly testable without spinning up the mux loop.
|
||||
fn encode_pending(p: PendingOp) -> BatchOp {
|
||||
let d = match (&p.data, p.encode_empty) {
|
||||
(Some(b), _) if !b.is_empty() => Some(B64.encode(b)),
|
||||
(Some(_), true) => Some(String::new()),
|
||||
_ => None,
|
||||
};
|
||||
BatchOp {
|
||||
op: p.op.into(),
|
||||
sid: p.sid,
|
||||
host: p.host,
|
||||
port: p.port,
|
||||
d,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -815,7 +855,7 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
|
||||
async fn fire_batch(
|
||||
sems: &Arc<HashMap<String, Arc<Semaphore>>>,
|
||||
fronter: &Arc<DomainFronter>,
|
||||
data_ops: Vec<BatchOp>,
|
||||
pending_ops: Vec<PendingOp>,
|
||||
data_replies: Vec<(usize, BatchedReply)>,
|
||||
) {
|
||||
let script_id = fronter.next_script_id();
|
||||
@@ -829,7 +869,13 @@ async fn fire_batch(
|
||||
tokio::spawn(async move {
|
||||
let _permit = permit;
|
||||
let t0 = std::time::Instant::now();
|
||||
let n_ops = data_ops.len();
|
||||
let n_ops = pending_ops.len();
|
||||
|
||||
// Encode payloads to base64 here, off the single mux thread.
|
||||
// With 50 ops × 64 KB this is up to ~3 MB of work; doing it on
|
||||
// the mux task previously serialized every op behind whichever
|
||||
// batch was currently encoding.
|
||||
let data_ops: Vec<BatchOp> = pending_ops.into_iter().map(encode_pending).collect();
|
||||
|
||||
// Bounded-wait: if the batch takes longer than the configured
|
||||
// batch timeout (Config::request_timeout_secs), all sessions in
|
||||
@@ -985,14 +1031,13 @@ pub async fn tunnel_connection(
|
||||
mux.record_preread_skip_port(port);
|
||||
None
|
||||
} else {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut buf = BytesMut::with_capacity(65536);
|
||||
let t0 = Instant::now();
|
||||
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read(&mut buf)).await {
|
||||
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read_buf(&mut buf)).await {
|
||||
Ok(Ok(0)) => return Ok(()),
|
||||
Ok(Ok(n)) => {
|
||||
Ok(Ok(_)) => {
|
||||
mux.record_preread_win(port, t0.elapsed());
|
||||
buf.truncate(n);
|
||||
Some(Arc::new(buf))
|
||||
Some(buf.freeze())
|
||||
}
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => {
|
||||
@@ -1008,14 +1053,10 @@ pub async fn tunnel_connection(
|
||||
ConnectDataOutcome::Unsupported => {
|
||||
mux.mark_connect_data_unsupported();
|
||||
let sid = connect_plain(host, port, mux).await?;
|
||||
// Recover the buffered ClientHello from the Arc so the
|
||||
// first tunnel_loop iteration can replay it. The mux task
|
||||
// may still hold the other ref during the unsupported
|
||||
// reply's settle window — fall back to a clone in that
|
||||
// race (rare; the reply path drops its ref before we
|
||||
// reach here in practice).
|
||||
let bytes = Arc::try_unwrap(data).unwrap_or_else(|a| (*a).clone());
|
||||
(sid, None, Some(bytes))
|
||||
// Replay the buffered ClientHello on the first tunnel_loop
|
||||
// iteration. `Bytes::clone()` is a cheap Arc bump — no
|
||||
// copy of the 64 KB buffer.
|
||||
(sid, None, Some(data))
|
||||
}
|
||||
},
|
||||
None => (connect_plain(host, port, mux).await?, None, None),
|
||||
@@ -1107,7 +1148,7 @@ async fn connect_plain(host: &str, port: u16, mux: &Arc<TunnelMux>) -> std::io::
|
||||
async fn connect_with_initial_data(
|
||||
host: &str,
|
||||
port: u16,
|
||||
data: Arc<Vec<u8>>,
|
||||
data: Bytes,
|
||||
mux: &Arc<TunnelMux>,
|
||||
) -> std::io::Result<ConnectDataOutcome> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
@@ -1212,10 +1253,30 @@ async fn tunnel_loop(
|
||||
sock: &mut TcpStream,
|
||||
sid: &str,
|
||||
mux: &Arc<TunnelMux>,
|
||||
mut pending_client_data: Option<Vec<u8>>,
|
||||
mut pending_client_data: Option<Bytes>,
|
||||
) -> std::io::Result<()> {
|
||||
let (mut reader, mut writer) = sock.split();
|
||||
let mut buf = vec![0u8; 65536];
|
||||
// `BytesMut` + `read_buf` + a per-read decision between
|
||||
// `split().freeze()` (zero-copy) and `copy_from_slice` + `clear`
|
||||
// (right-sized copy, buffer reused).
|
||||
//
|
||||
// Why the split decision: `bytes` 1.x refcounts the *whole*
|
||||
// backing allocation, so a frozen `Bytes` from a partial read
|
||||
// pins all `READ_CHUNK` bytes until it drops. Under semaphore
|
||||
// saturation or reply timeouts, dozens of small TLS records or
|
||||
// HTTP/2 frames can each retain ~64 KB instead of their actual
|
||||
// payload size — order-of-magnitude memory regression on
|
||||
// constrained targets (router builds with 64 MB RAM).
|
||||
//
|
||||
// Threshold: at ≥ half-buffer the saved memcpy outweighs the
|
||||
// wasted slack, and these reads are typically streaming bulk
|
||||
// transfer where the `Bytes` flushes through the mux quickly.
|
||||
// Below that, copy out and `clear()` so the same allocation
|
||||
// serves the next read — equivalent memory profile to the old
|
||||
// `vec![0u8; 65536]` + `to_vec()` code on small-read workloads.
|
||||
const READ_CHUNK: usize = 65536;
|
||||
const ZERO_COPY_THRESHOLD: usize = READ_CHUNK / 2;
|
||||
let mut buf = BytesMut::with_capacity(READ_CHUNK);
|
||||
let mut consecutive_empty = 0u32;
|
||||
|
||||
loop {
|
||||
@@ -1254,11 +1315,28 @@ async fn tunnel_loop(
|
||||
(true, _) => Duration::from_secs(30),
|
||||
};
|
||||
|
||||
match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
|
||||
buf.reserve(READ_CHUNK);
|
||||
match tokio::time::timeout(read_timeout, reader.read_buf(&mut buf)).await {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => {
|
||||
consecutive_empty = 0;
|
||||
Some(buf[..n].to_vec())
|
||||
if n >= ZERO_COPY_THRESHOLD {
|
||||
// Big read: split off the filled region. The
|
||||
// frozen `Bytes` is at-least-half-full, so the
|
||||
// saved 64 KB memcpy outweighs the brief
|
||||
// retention until the mux drains.
|
||||
Some(buf.split().freeze())
|
||||
} else {
|
||||
// Small read: copy out a payload-sized `Bytes`
|
||||
// and `clear()` so the buffer is reused on the
|
||||
// next iter (no `reserve` allocation needed
|
||||
// because the alloc stays uniquely owned).
|
||||
// Bounds retention to actual data even when
|
||||
// the mux is backpressured.
|
||||
let owned = Bytes::copy_from_slice(&buf[..n]);
|
||||
buf.clear();
|
||||
Some(owned)
|
||||
}
|
||||
}
|
||||
Ok(Err(_)) => break,
|
||||
Err(_) => None,
|
||||
@@ -1275,7 +1353,7 @@ async fn tunnel_loop(
|
||||
continue;
|
||||
}
|
||||
|
||||
let data = client_data.unwrap_or_default();
|
||||
let data = client_data.unwrap_or_else(Bytes::new);
|
||||
let was_empty_poll = data.is_empty();
|
||||
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
@@ -1664,7 +1742,7 @@ mod tests {
|
||||
let mut server_side = accept.await.unwrap();
|
||||
|
||||
let (mux, mut rx) = mux_for_test();
|
||||
let pending = Some(b"CLIENTHELLO".to_vec());
|
||||
let pending = Some(Bytes::from_static(b"CLIENTHELLO"));
|
||||
|
||||
let loop_handle = tokio::spawn({
|
||||
let mux = mux.clone();
|
||||
@@ -1907,6 +1985,199 @@ mod tests {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fire_first_op_never_fires() {
|
||||
// Empty accumulator: even a single op larger than the payload cap
|
||||
// must not fire — there's nothing to fire yet, and the op gets
|
||||
// added (it will simply be the only op in the next batch).
|
||||
assert!(!should_fire(0, 0, 0));
|
||||
assert!(!should_fire(0, 0, MAX_BATCH_PAYLOAD_BYTES + 1_000_000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fire_at_max_ops_threshold() {
|
||||
// 49 already-queued ops + 50th: still fits (boundary is `>=`).
|
||||
assert!(!should_fire(MAX_BATCH_OPS - 1, 0, 100));
|
||||
// 50 already-queued ops + 51st: must fire.
|
||||
assert!(should_fire(MAX_BATCH_OPS, 0, 100));
|
||||
// Well past the cap: must fire.
|
||||
assert!(should_fire(MAX_BATCH_OPS + 5, 0, 100));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_fire_when_payload_would_exceed_cap() {
|
||||
// Exactly at the cap is fine — strict `>`.
|
||||
assert!(!should_fire(
|
||||
10,
|
||||
MAX_BATCH_PAYLOAD_BYTES - 100,
|
||||
100,
|
||||
));
|
||||
// One byte over: fire.
|
||||
assert!(should_fire(
|
||||
10,
|
||||
MAX_BATCH_PAYLOAD_BYTES - 100,
|
||||
101,
|
||||
));
|
||||
// Sum overflow well past the cap: fire.
|
||||
assert!(should_fire(
|
||||
10,
|
||||
MAX_BATCH_PAYLOAD_BYTES,
|
||||
1,
|
||||
));
|
||||
}
|
||||
|
||||
/// Reply indices must point at the slot the op occupies *within its
|
||||
/// batch*. Pre-flush ops are 0..N-1 in batch A; post-flush ops
|
||||
/// restart at 0 in batch B. If this regresses, `fire_batch`'s
|
||||
/// `batch_resp.r.get(idx)` lookup hands the wrong response (or
|
||||
/// `None`) to the wrong session — silent data corruption that
|
||||
/// the encode-layer tests can't catch.
|
||||
#[tokio::test]
|
||||
async fn batch_accum_reindexes_after_flush() {
|
||||
// Stand-alone helper that mirrors `push_or_fire`'s push step
|
||||
// without the fire_batch call — lets us simulate a flush with
|
||||
// `mem::take` and assert the post-flush indexing without
|
||||
// mocking the whole tunnel_request stack.
|
||||
fn push_no_fire(
|
||||
accum: &mut BatchAccum,
|
||||
op: PendingOp,
|
||||
op_bytes: usize,
|
||||
reply: BatchedReply,
|
||||
) {
|
||||
let idx = accum.pending_ops.len();
|
||||
accum.pending_ops.push(op);
|
||||
accum.data_replies.push((idx, reply));
|
||||
accum.payload_bytes += op_bytes;
|
||||
}
|
||||
|
||||
let mk_op = |sid: &str| PendingOp {
|
||||
op: "data",
|
||||
sid: Some(sid.into()),
|
||||
host: None,
|
||||
port: None,
|
||||
data: Some(Bytes::from_static(b"x")),
|
||||
encode_empty: false,
|
||||
};
|
||||
let mk_reply = || oneshot::channel::<Result<(TunnelResponse, String), String>>().0;
|
||||
|
||||
let mut accum = BatchAccum::new();
|
||||
|
||||
// Batch A: 3 ops at indices 0, 1, 2.
|
||||
push_no_fire(&mut accum, mk_op("a0"), 4, mk_reply());
|
||||
push_no_fire(&mut accum, mk_op("a1"), 4, mk_reply());
|
||||
push_no_fire(&mut accum, mk_op("a2"), 4, mk_reply());
|
||||
assert_eq!(accum.pending_ops.len(), 3);
|
||||
assert_eq!(
|
||||
accum.data_replies.iter().map(|(i, _)| *i).collect::<Vec<_>>(),
|
||||
vec![0, 1, 2],
|
||||
);
|
||||
assert_eq!(accum.payload_bytes, 12);
|
||||
|
||||
// Simulate the flush: take the queued state and reset the byte
|
||||
// counter (matches what `push_or_fire` does after `fire_batch`).
|
||||
let _flushed_ops = std::mem::take(&mut accum.pending_ops);
|
||||
let _flushed_replies = std::mem::take(&mut accum.data_replies);
|
||||
accum.payload_bytes = 0;
|
||||
|
||||
// Batch B: 2 ops, indices restart at 0.
|
||||
push_no_fire(&mut accum, mk_op("b0"), 4, mk_reply());
|
||||
push_no_fire(&mut accum, mk_op("b1"), 4, mk_reply());
|
||||
assert_eq!(accum.pending_ops.len(), 2);
|
||||
assert_eq!(
|
||||
accum.data_replies.iter().map(|(i, _)| *i).collect::<Vec<_>>(),
|
||||
vec![0, 1],
|
||||
"post-flush indices must restart at 0 — otherwise fire_batch's \
|
||||
batch_resp.r.get(idx) returns None and every session in the \
|
||||
second batch sees a missing-response error"
|
||||
);
|
||||
assert_eq!(accum.payload_bytes, 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_pending_data_op_with_payload_emits_base64() {
|
||||
let op = PendingOp {
|
||||
op: "data",
|
||||
sid: Some("sid-1".into()),
|
||||
host: None,
|
||||
port: None,
|
||||
data: Some(Bytes::from_static(b"hello")),
|
||||
encode_empty: false,
|
||||
};
|
||||
let b = encode_pending(op);
|
||||
assert_eq!(b.op, "data");
|
||||
assert_eq!(b.sid.as_deref(), Some("sid-1"));
|
||||
assert_eq!(b.d.as_deref(), Some(B64.encode(b"hello").as_str()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_pending_omits_d_for_empty_polls_and_close() {
|
||||
// Empty-poll Data: mux_loop converts empty Bytes to data: None.
|
||||
let empty_poll = PendingOp {
|
||||
op: "data",
|
||||
sid: Some("sid-2".into()),
|
||||
host: None,
|
||||
port: None,
|
||||
data: None,
|
||||
encode_empty: false,
|
||||
};
|
||||
assert!(encode_pending(empty_poll).d.is_none());
|
||||
|
||||
// UDP poll with no payload: same shape.
|
||||
let udp_poll = PendingOp {
|
||||
op: "udp_data",
|
||||
sid: Some("sid-3".into()),
|
||||
host: None,
|
||||
port: None,
|
||||
data: None,
|
||||
encode_empty: false,
|
||||
};
|
||||
assert!(encode_pending(udp_poll).d.is_none());
|
||||
|
||||
// Close has no data and no reply — `d` must stay omitted.
|
||||
let close = PendingOp {
|
||||
op: "close",
|
||||
sid: Some("sid-4".into()),
|
||||
host: None,
|
||||
port: None,
|
||||
data: None,
|
||||
encode_empty: false,
|
||||
};
|
||||
assert!(encode_pending(close).d.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_pending_connect_data_emits_empty_string_when_data_is_empty() {
|
||||
// Defensive: ConnectData's wire contract is that `d` is always
|
||||
// present (its presence is the signal that the caller is opting
|
||||
// into the bundled-first-bytes flow). If an empty Bytes ever
|
||||
// reaches the encoder, we must serialize `d: ""` not omit it.
|
||||
let op = PendingOp {
|
||||
op: "connect_data",
|
||||
sid: None,
|
||||
host: Some("example.com".into()),
|
||||
port: Some(443),
|
||||
data: Some(Bytes::new()),
|
||||
encode_empty: true,
|
||||
};
|
||||
let b = encode_pending(op);
|
||||
assert_eq!(b.op, "connect_data");
|
||||
assert_eq!(b.d.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn encode_pending_connect_data_with_payload_encodes_normally() {
|
||||
let op = PendingOp {
|
||||
op: "connect_data",
|
||||
sid: None,
|
||||
host: Some("example.com".into()),
|
||||
port: Some(443),
|
||||
data: Some(Bytes::from_static(b"\x16\x03\x01")), // ClientHello prefix
|
||||
encode_empty: true,
|
||||
};
|
||||
let b = encode_pending(op);
|
||||
assert_eq!(b.d.as_deref(), Some(B64.encode(b"\x16\x03\x01").as_str()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preread_counters_track_each_outcome() {
|
||||
let (mux, _rx) = mux_for_test();
|
||||
|
||||
Reference in New Issue
Block a user