perf(tunnel): zero-copy mux + base64 off mux thread (#881)

Performance refactor of full-tunnel mode hot path. No wire-protocol or
behavior changes — internal data flow only.

**1. Zero-copy reads via `Bytes`/`BytesMut`**
`tunnel_loop` and the SOCKS5 UDP receive loop drop their per-iteration
slice `to_vec()` copies into a fresh `Vec<u8>`.
`MuxMsg::{ConnectData,Data,UdpOpen,UdpData}` now carry `Bytes` instead of
`Vec<u8>`/`Arc<Vec<u8>>`; the `Arc::try_unwrap` dance for
`pending_client_data` is gone (`Bytes` is already reference-counted, so a
clone is a cheap refcount bump).

TCP path is threshold-based to avoid the obvious memory regression:
- n ≥ 32 KB: `BytesMut::split().freeze()` — saves the 64 KB memcpy on
  hot downloads.
- n < 32 KB: `Bytes::copy_from_slice` + `buf.clear()` — payload-sized
  retention. Without this split, a queued tiny TLS record would refcount-
  pin the full 64 KB recv buffer (worst case ~96 MB on a backpressured
  tunnel).

UDP path: fixed `Vec<u8>` recv buffer + `Bytes::copy_from_slice` after
the 9 KB MAX_UDP_PAYLOAD_BYTES guard. `parse_socks5_udp_packet` split
into `_offsets` + `&[u8]` wrapper so callers stay on the reusable buffer.

**2. Base64 encoding moved off the single mux thread**
New internal `PendingOp { data: Option<Bytes>, encode_empty: bool }`
flows through `mux_loop` with raw bytes. Actual `B64.encode(...)` runs
in `fire_batch`'s spawned task, after the per-deployment semaphore. Up
to ~3 MB of encoding per batch (50 ops × 64 KB) no longer serializes
the single mux task.

**3. Code quality**
- `BatchAccum::push_or_fire` collapses 4× ~25-line match arms → ~10 each.
- `should_fire(pending_len, payload_bytes, op_bytes)` extracted with
  `saturating_add` for a self-contained contract.
- `encode_pending(p) -> BatchOp` extracted as a free function so the
  encoding contract is directly testable.

**Tests:** 208/208 (was 200, +8 new):
- `encode_pending_*` × 4 — base64-encode contract per MuxMsg variant
- `should_fire_*` × 3 — first-op, MAX_BATCH_OPS boundary, payload cap
- `batch_accum_reindexes_after_flush` — regression test for post-flush
  reply index lookup in `fire_batch`

**Public API:** `TunnelMux::udp_open` and `udp_data` now take
`data: impl Into<Bytes>` instead of `Vec<u8>`. Existing call sites
keep compiling because `Bytes` implements `From<Vec<u8>>`, so a
`Vec<u8>` argument still satisfies the new bound.

Reviewed via Anthropic Claude.

Co-Authored-By: dazzling-no-more <noreply@github.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
dazzling-no-more
2026-05-08 06:25:48 +04:00
committed by GitHub
parent 624914241a
commit 54552bbdac
2 changed files with 469 additions and 162 deletions
+48 -12
View File
@@ -3,6 +3,7 @@ use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpListener, TcpStream, UdpSocket};
use tokio::sync::{mpsc, Mutex};
@@ -965,7 +966,7 @@ struct SocksUdpTarget {
/// to abort mid-await.
struct UdpRelaySession {
sid: String,
uplink: mpsc::Sender<Vec<u8>>,
uplink: mpsc::Sender<Bytes>,
}
/// All per-ASSOCIATE UDP relay state behind a single mutex so insertion
@@ -991,7 +992,7 @@ impl UdpRelayState {
}
}
fn get_uplink(&self, target: &SocksUdpTarget) -> Option<mpsc::Sender<Vec<u8>>> {
fn get_uplink(&self, target: &SocksUdpTarget) -> Option<mpsc::Sender<Bytes>> {
self.sessions.get(target).map(|s| s.uplink.clone())
}
@@ -1118,7 +1119,15 @@ async fn handle_socks5_udp_associate(
client_peer_ip
);
let mut buf = vec![0u8; SOCKS5_UDP_RECV_BUF_BYTES];
// Fixed reusable recv buffer. We deliberately don't go the
// `BytesMut::split().freeze()` route here even though `tunnel_loop`
// does: in TCP the read region IS the payload, but UDP always
// slices the SOCKS5 header off, so we'd be copying out anyway —
// and a frozen `Bytes` from the recv buf would refcount-pin the
// full ~65 KB allocation behind a tiny DNS reply, ballooning
// memory under bursts. Right-sized `Bytes::copy_from_slice` on
// accepted payloads keeps retention proportional to actual data.
let mut recv_buf = vec![0u8; SOCKS5_UDP_RECV_BUF_BYTES];
let mut control_buf = [0u8; 1];
let mut client_addr: Option<SocketAddr> = None;
let state: Arc<Mutex<UdpRelayState>> = Arc::new(Mutex::new(UdpRelayState::new()));
@@ -1134,7 +1143,7 @@ async fn handle_socks5_udp_associate(
loop {
tokio::select! {
recv = udp.recv_from(&mut buf) => {
recv = udp.recv_from(&mut recv_buf) => {
let (n, peer) = match recv {
Ok(v) => v,
Err(e) => {
@@ -1142,6 +1151,7 @@ async fn handle_socks5_udp_associate(
break;
}
};
// Source-IP check: anything not from the SOCKS5 client's
// host is dropped silently.
if peer.ip() != client_peer_ip {
@@ -1162,9 +1172,10 @@ async fn handle_socks5_udp_associate(
// can race one bad packet to DoS the legitimate client
// (whose real datagram, sent from a different ephemeral
// port, would then be silently rejected).
let Some((target, payload)) = parse_socks5_udp_packet(&buf[..n]) else {
let Some((target, payload_off)) = parse_socks5_udp_packet_offsets(&recv_buf[..n]) else {
continue;
};
let payload_slice = &recv_buf[payload_off..n];
// Issue #213: client-side QUIC block. UDP/443 is
// HTTP/3 — drop the datagram silently so the client
@@ -1206,19 +1217,26 @@ async fn handle_socks5_udp_associate(
// the mux. Each datagram costs ~payload * 1.33 in the
// batched JSON envelope plus tunnel-node CPU; uncapped,
// a runaway client can exhaust Apps Script quota.
if payload.len() > MAX_UDP_PAYLOAD_BYTES {
if payload_slice.len() > MAX_UDP_PAYLOAD_BYTES {
oversized_dropped += 1;
if oversized_dropped == 1 || oversized_dropped.is_multiple_of(100) {
tracing::debug!(
"udp datagram dropped: {} B > {} B (count={})",
payload.len(),
payload_slice.len(),
MAX_UDP_PAYLOAD_BYTES,
oversized_dropped,
);
}
continue;
}
let payload = payload.to_vec();
// Right-sized copy: the queued/in-flight payload owns its
// own allocation, so the recv buffer can be reused on the
// next iteration without keeping every queued datagram
// alive. Sized to the actual payload (≤ MAX_UDP_PAYLOAD_BYTES
// = 9 KB after the guard above), not the full ~65 KB recv
// buffer.
let payload = Bytes::copy_from_slice(payload_slice);
// Fast path: existing session — push payload onto its
// bounded uplink queue, drop on overflow (UDP semantics).
@@ -1292,7 +1310,7 @@ async fn handle_socks5_udp_associate(
continue;
}
let (uplink_tx, uplink_rx) = mpsc::channel::<Vec<u8>>(UDP_UPLINK_QUEUE);
let (uplink_tx, uplink_rx) = mpsc::channel::<Bytes>(UDP_UPLINK_QUEUE);
let task_mux = mux.clone();
let task_udp = udp.clone();
let task_target = target.clone();
@@ -1365,7 +1383,7 @@ async fn udp_session_task(
sid: String,
target: SocksUdpTarget,
client_addr: SocketAddr,
mut uplink_rx: mpsc::Receiver<Vec<u8>>,
mut uplink_rx: mpsc::Receiver<Bytes>,
) {
let mut backoff = UDP_INITIAL_POLL_DELAY;
loop {
@@ -1473,7 +1491,20 @@ async fn write_socks5_reply(
sock.flush().await
}
fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
/// Parse the SOCKS5 UDP frame header and return the target plus the byte
/// offset at which the payload starts. Splitting "structure parsing"
/// from "give me a payload slice" lets the recv hot path stay on a
/// fixed reusable `Vec<u8>` buffer and only allocate a right-sized
/// `Bytes::copy_from_slice(&recv_buf[off..n])` for accepted payloads
/// (after the size guard). DO NOT change this back to a zero-copy
/// `Bytes::slice` path: that was tried and reverted because slicing
/// the recv buffer with `bytes` 1.x refcounts the whole ~65 KB
/// allocation, so a queued tiny DNS reply pinned the full datagram-
/// sized buffer until it drained — burst retention regressed by
/// orders of magnitude on UDP-heavy workloads. The thin
/// `parse_socks5_udp_packet` wrapper below keeps existing `&[u8]`
/// callers (tests) working.
fn parse_socks5_udp_packet_offsets(buf: &[u8]) -> Option<(SocksUdpTarget, usize)> {
if buf.len() < 4 || buf[0] != 0 || buf[1] != 0 || buf[2] != 0 {
return None;
}
@@ -1528,10 +1559,15 @@ fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
atyp,
addr,
},
&buf[pos..],
pos,
))
}
/// Slice-returning convenience wrapper around
/// `parse_socks5_udp_packet_offsets`.
///
/// Parses the SOCKS5 UDP frame header in `buf` and returns the target
/// together with the payload, i.e. everything from the reported payload
/// offset to the end of `buf`. Returns `None` whenever the offsets
/// parser rejects the frame.
fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
    parse_socks5_udp_packet_offsets(buf)
        .map(|(target, payload_start)| (target, &buf[payload_start..]))
}
fn build_socks5_udp_packet(target: &SocksUdpTarget, payload: &[u8]) -> Vec<u8> {
let mut out = Vec::with_capacity(4 + target.addr.len() + 2 + payload.len() + 1);
out.extend_from_slice(&[0, 0, 0, target.atyp]);
+421 -150
View File
@@ -19,6 +19,7 @@ use std::time::{Duration, Instant};
use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine;
use bytes::{Bytes, BytesMut};
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, Semaphore};
@@ -163,25 +164,26 @@ enum MuxMsg {
ConnectData {
host: String,
port: u16,
// Arc so the caller can hand the buffer to the mux AND keep a ref
// for the fallback path without an extra 64 KB copy per session.
data: Arc<Vec<u8>>,
// `Bytes` is internally Arc-backed, so the caller can cheaply
// clone() to keep its own reference for the unsupported-fallback
// replay path without an extra 64 KB copy per session.
data: Bytes,
reply: BatchedReply,
},
Data {
sid: String,
data: Vec<u8>,
data: Bytes,
reply: BatchedReply,
},
UdpOpen {
host: String,
port: u16,
data: Vec<u8>,
data: Bytes,
reply: BatchedReply,
},
UdpData {
sid: String,
data: Vec<u8>,
data: Bytes,
reply: BatchedReply,
},
Close {
@@ -189,6 +191,25 @@ enum MuxMsg {
},
}
/// Raw, not-yet-encoded form of a batch operation. Lives only inside
/// `mux_loop` and gets converted to `BatchOp` (with base64-encoded `d`)
/// inside `fire_batch`'s spawned task — keeping the encoding work off
/// the single mux thread, which previously had to base64 every op
/// inline before it could move on to the next message.
struct PendingOp {
/// Wire-level operation name, copied into `BatchOp::op` by
/// `encode_pending`. `mux_loop` produces `"connect_data"`, `"data"`,
/// `"udp_open"`, `"udp_data"` and `"close"`.
op: &'static str,
/// Session id. `mux_loop` sets it for `data`, `udp_data` and `close`
/// ops and leaves it `None` for `connect_data` / `udp_open`.
sid: Option<String>,
/// Target host. `mux_loop` sets it only for `connect_data` and
/// `udp_open`.
host: Option<String>,
/// Target port; set alongside `host`.
port: Option<u16>,
/// Raw payload. `None` for empty polls / opless ops; `Some` even
/// when empty preserves the connect_data shape (always emits `d`).
data: Option<Bytes>,
/// True for ops that must serialize `d` even when empty (currently
/// only `connect_data`, which uses presence of `d` as the signal
/// that the caller is opting into the bundled-first-bytes flow).
encode_empty: bool,
}
pub struct TunnelMux {
tx: mpsc::Sender<MuxMsg>,
/// Set to `true` after the first time the tunnel-node rejects
@@ -316,13 +337,13 @@ impl TunnelMux {
&self,
host: &str,
port: u16,
data: Vec<u8>,
data: impl Into<Bytes>,
) -> Result<TunnelResponse, String> {
let (reply_tx, reply_rx) = oneshot::channel();
self.send(MuxMsg::UdpOpen {
host: host.to_string(),
port,
data,
data: data.into(),
reply: reply_tx,
})
.await;
@@ -333,11 +354,15 @@ impl TunnelMux {
}
}
pub async fn udp_data(&self, sid: &str, data: Vec<u8>) -> Result<TunnelResponse, String> {
pub async fn udp_data(
&self,
sid: &str,
data: impl Into<Bytes>,
) -> Result<TunnelResponse, String> {
let (reply_tx, reply_rx) = oneshot::channel();
self.send(MuxMsg::UdpData {
sid: sid.to_string(),
data,
data: data.into(),
reply: reply_tx,
})
.await;
@@ -619,10 +644,8 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
}
// Split: plain connects go parallel, data-bearing ops get batched.
let mut data_ops: Vec<BatchOp> = Vec::new();
let mut data_replies: Vec<(usize, BatchedReply)> = Vec::new();
let mut accum = BatchAccum::new();
let mut close_sids: Vec<String> = Vec::new();
let mut batch_payload_bytes: usize = 0;
for msg in msgs {
match msg {
@@ -648,68 +671,28 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
data,
reply,
} => {
let encoded = Some(B64.encode(data.as_slice()));
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
if !data_ops.is_empty()
&& (data_ops.len() >= MAX_BATCH_OPS
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
{
fire_batch(
&sems,
&fronter,
std::mem::take(&mut data_ops),
std::mem::take(&mut data_replies),
)
.await;
batch_payload_bytes = 0;
}
let idx = data_ops.len();
data_ops.push(BatchOp {
op: "connect_data".into(),
let op_bytes = encoded_len(data.len());
let op = PendingOp {
op: "connect_data",
sid: None,
host: Some(host),
port: Some(port),
d: encoded,
});
data_replies.push((idx, reply));
batch_payload_bytes += op_bytes;
data: Some(data),
encode_empty: true,
};
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
}
MuxMsg::Data { sid, data, reply } => {
let encoded = if data.is_empty() {
None
} else {
Some(B64.encode(&data))
};
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
// If adding this op would exceed limits, fire current
// batch first and start a new one.
if !data_ops.is_empty()
&& (data_ops.len() >= MAX_BATCH_OPS
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
{
fire_batch(
&sems,
&fronter,
std::mem::take(&mut data_ops),
std::mem::take(&mut data_replies),
)
.await;
batch_payload_bytes = 0;
}
let idx = data_ops.len();
data_ops.push(BatchOp {
op: "data".into(),
let op_bytes = encoded_len(data.len());
let op = PendingOp {
op: "data",
sid: Some(sid),
host: None,
port: None,
d: encoded,
});
data_replies.push((idx, reply));
batch_payload_bytes += op_bytes;
data: if data.is_empty() { None } else { Some(data) },
encode_empty: false,
};
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
}
MuxMsg::UdpOpen {
host,
@@ -717,70 +700,28 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
data,
reply,
} => {
let encoded = if data.is_empty() {
None
} else {
Some(B64.encode(&data))
};
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
if !data_ops.is_empty()
&& (data_ops.len() >= MAX_BATCH_OPS
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
{
fire_batch(
&sems,
&fronter,
std::mem::take(&mut data_ops),
std::mem::take(&mut data_replies),
)
.await;
batch_payload_bytes = 0;
}
let idx = data_ops.len();
data_ops.push(BatchOp {
op: "udp_open".into(),
let op_bytes = encoded_len(data.len());
let op = PendingOp {
op: "udp_open",
sid: None,
host: Some(host),
port: Some(port),
d: encoded,
});
data_replies.push((idx, reply));
batch_payload_bytes += op_bytes;
data: if data.is_empty() { None } else { Some(data) },
encode_empty: false,
};
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
}
MuxMsg::UdpData { sid, data, reply } => {
let encoded = if data.is_empty() {
None
} else {
Some(B64.encode(&data))
};
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
if !data_ops.is_empty()
&& (data_ops.len() >= MAX_BATCH_OPS
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
{
fire_batch(
&sems,
&fronter,
std::mem::take(&mut data_ops),
std::mem::take(&mut data_replies),
)
.await;
batch_payload_bytes = 0;
}
let idx = data_ops.len();
data_ops.push(BatchOp {
op: "udp_data".into(),
let op_bytes = encoded_len(data.len());
let op = PendingOp {
op: "udp_data",
sid: Some(sid),
host: None,
port: None,
d: encoded,
});
data_replies.push((idx, reply));
batch_payload_bytes += op_bytes;
data: if data.is_empty() { None } else { Some(data) },
encode_empty: false,
};
accum.push_or_fire(op, op_bytes, reply, &sems, &fronter).await;
}
MuxMsg::Close { sid } => {
close_sids.push(sid);
@@ -788,21 +729,120 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
}
}
// `close` ops piggyback on whatever batch we're about to fire — no
// reply channel, no payload, just tell tunnel-node to drop the sid.
for sid in close_sids {
data_ops.push(BatchOp {
op: "close".into(),
accum.pending_ops.push(PendingOp {
op: "close",
sid: Some(sid),
host: None,
port: None,
d: None,
data: None,
encode_empty: false,
});
}
if data_ops.is_empty() {
if accum.pending_ops.is_empty() {
continue;
}
fire_batch(&sems, &fronter, data_ops, data_replies).await;
fire_batch(&sems, &fronter, accum.pending_ops, accum.data_replies).await;
}
}
/// Per-iteration accumulator for `mux_loop`. Owns the three fields that
/// the data-bearing arms used to mutate in lockstep, with a single
/// `push_or_fire` entry point so the cap-then-push pattern lives in one
/// place instead of being copy-pasted into every arm.
struct BatchAccum {
/// Ops queued for the next batch, still carrying raw (un-encoded)
/// payloads.
pending_ops: Vec<PendingOp>,
/// One `(index, reply)` entry per op pushed through `push_or_fire`;
/// `index` is the op's position in `pending_ops` at push time.
data_replies: Vec<(usize, BatchedReply)>,
/// Running sum of the `op_bytes` values of the queued ops; reset to
/// zero by `push_or_fire` after it fires a batch.
payload_bytes: usize,
}
impl BatchAccum {
fn new() -> Self {
Self {
pending_ops: Vec::new(),
data_replies: Vec::new(),
payload_bytes: 0,
}
}
/// Append `op` (with its `reply` channel and pre-computed `op_bytes`),
/// firing the current accumulator first if `op` would push us past
/// `MAX_BATCH_OPS` or `MAX_BATCH_PAYLOAD_BYTES`. After a fire the
/// accumulator is fresh for the new op.
async fn push_or_fire(
&mut self,
op: PendingOp,
op_bytes: usize,
reply: BatchedReply,
sems: &Arc<HashMap<String, Arc<Semaphore>>>,
fronter: &Arc<DomainFronter>,
) {
if should_fire(self.pending_ops.len(), self.payload_bytes, op_bytes) {
fire_batch(
sems,
fronter,
std::mem::take(&mut self.pending_ops),
std::mem::take(&mut self.data_replies),
)
.await;
self.payload_bytes = 0;
}
let idx = self.pending_ops.len();
self.pending_ops.push(op);
self.data_replies.push((idx, reply));
self.payload_bytes += op_bytes;
}
}
/// Decide whether `BatchAccum::push_or_fire` must flush the current
/// batch before queueing one more op.
///
/// * `pending_len` — number of ops already queued in the batch.
/// * `payload_bytes` — base64 bytes already accounted for in the batch.
/// * `op_bytes` — base64 size of the op about to be added.
///
/// Returns `true` when the batch is non-empty and either already holds
/// `MAX_BATCH_OPS` ops or would exceed `MAX_BATCH_PAYLOAD_BYTES` with
/// the new op. An empty batch never fires: there is nothing to flush,
/// so the op is queued regardless of its size.
///
/// The byte sum uses `saturating_add`, so a huge `op_bytes` clamps to
/// "fire" instead of wrapping around and slipping under the cap. Kept
/// as a free function so these boundaries can be unit-tested without a
/// real `fire_batch`.
fn should_fire(pending_len: usize, payload_bytes: usize, op_bytes: usize) -> bool {
    if pending_len == 0 {
        return false;
    }
    if pending_len >= MAX_BATCH_OPS {
        return true;
    }
    payload_bytes.saturating_add(op_bytes) > MAX_BATCH_PAYLOAD_BYTES
}
/// Base64-encoded length of `n` raw bytes with standard padding:
/// `ceil(n / 3) * 4`.
///
/// Used by `mux_loop` to enforce `MAX_BATCH_PAYLOAD_BYTES` without doing
/// the actual encoding inline — that work happens in `fire_batch`'s
/// spawned task.
///
/// The result is exact whenever it fits in `usize`. For a pathological
/// `n` whose encoded length would not fit, the multiplication saturates
/// to `usize::MAX` instead of panicking (debug builds) or wrapping to a
/// small value (release builds). A wrapped value would let an oversized
/// op slip under the payload cap that `should_fire` guards with
/// `saturating_add`.
fn encoded_len(n: usize) -> usize {
    n.div_ceil(3).saturating_mul(4)
}
/// Build the wire-shape `BatchOp` from an internal `PendingOp`. Free
/// function so the encoding contract — non-empty data → encoded,
/// empty connect_data → `Some("")`, anything else empty → `None` — is
/// directly testable without spinning up the mux loop.
fn encode_pending(p: PendingOp) -> BatchOp {
let d = match (&p.data, p.encode_empty) {
(Some(b), _) if !b.is_empty() => Some(B64.encode(b)),
(Some(_), true) => Some(String::new()),
_ => None,
};
BatchOp {
op: p.op.into(),
sid: p.sid,
host: p.host,
port: p.port,
d,
}
}
@@ -815,7 +855,7 @@ async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>, c
async fn fire_batch(
sems: &Arc<HashMap<String, Arc<Semaphore>>>,
fronter: &Arc<DomainFronter>,
data_ops: Vec<BatchOp>,
pending_ops: Vec<PendingOp>,
data_replies: Vec<(usize, BatchedReply)>,
) {
let script_id = fronter.next_script_id();
@@ -829,7 +869,13 @@ async fn fire_batch(
tokio::spawn(async move {
let _permit = permit;
let t0 = std::time::Instant::now();
let n_ops = data_ops.len();
let n_ops = pending_ops.len();
// Encode payloads to base64 here, off the single mux thread.
// With 50 ops × 64 KB this is up to ~3 MB of work; doing it on
// the mux task previously serialized every op behind whichever
// batch was currently encoding.
let data_ops: Vec<BatchOp> = pending_ops.into_iter().map(encode_pending).collect();
// Bounded-wait: if the batch takes longer than the configured
// batch timeout (Config::request_timeout_secs), all sessions in
@@ -985,14 +1031,13 @@ pub async fn tunnel_connection(
mux.record_preread_skip_port(port);
None
} else {
let mut buf = vec![0u8; 65536];
let mut buf = BytesMut::with_capacity(65536);
let t0 = Instant::now();
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read(&mut buf)).await {
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read_buf(&mut buf)).await {
Ok(Ok(0)) => return Ok(()),
Ok(Ok(n)) => {
Ok(Ok(_)) => {
mux.record_preread_win(port, t0.elapsed());
buf.truncate(n);
Some(Arc::new(buf))
Some(buf.freeze())
}
Ok(Err(e)) => return Err(e),
Err(_) => {
@@ -1008,14 +1053,10 @@ pub async fn tunnel_connection(
ConnectDataOutcome::Unsupported => {
mux.mark_connect_data_unsupported();
let sid = connect_plain(host, port, mux).await?;
// Recover the buffered ClientHello from the Arc so the
// first tunnel_loop iteration can replay it. The mux task
// may still hold the other ref during the unsupported
// reply's settle window — fall back to a clone in that
// race (rare; the reply path drops its ref before we
// reach here in practice).
let bytes = Arc::try_unwrap(data).unwrap_or_else(|a| (*a).clone());
(sid, None, Some(bytes))
// Replay the buffered ClientHello on the first tunnel_loop
// iteration. `Bytes::clone()` is a cheap Arc bump — no
// copy of the 64 KB buffer.
(sid, None, Some(data))
}
},
None => (connect_plain(host, port, mux).await?, None, None),
@@ -1107,7 +1148,7 @@ async fn connect_plain(host: &str, port: u16, mux: &Arc<TunnelMux>) -> std::io::
async fn connect_with_initial_data(
host: &str,
port: u16,
data: Arc<Vec<u8>>,
data: Bytes,
mux: &Arc<TunnelMux>,
) -> std::io::Result<ConnectDataOutcome> {
let (reply_tx, reply_rx) = oneshot::channel();
@@ -1212,10 +1253,30 @@ async fn tunnel_loop(
sock: &mut TcpStream,
sid: &str,
mux: &Arc<TunnelMux>,
mut pending_client_data: Option<Vec<u8>>,
mut pending_client_data: Option<Bytes>,
) -> std::io::Result<()> {
let (mut reader, mut writer) = sock.split();
let mut buf = vec![0u8; 65536];
// `BytesMut` + `read_buf` + a per-read decision between
// `split().freeze()` (zero-copy) and `copy_from_slice` + `clear`
// (right-sized copy, buffer reused).
//
// Why the split decision: `bytes` 1.x refcounts the *whole*
// backing allocation, so a frozen `Bytes` from a partial read
// pins all `READ_CHUNK` bytes until it drops. Under semaphore
// saturation or reply timeouts, dozens of small TLS records or
// HTTP/2 frames can each retain ~64 KB instead of their actual
// payload size — order-of-magnitude memory regression on
// constrained targets (router builds with 64 MB RAM).
//
// Threshold: at ≥ half-buffer the saved memcpy outweighs the
// wasted slack, and these reads are typically streaming bulk
// transfer where the `Bytes` flushes through the mux quickly.
// Below that, copy out and `clear()` so the same allocation
// serves the next read — equivalent memory profile to the old
// `vec![0u8; 65536]` + `to_vec()` code on small-read workloads.
const READ_CHUNK: usize = 65536;
const ZERO_COPY_THRESHOLD: usize = READ_CHUNK / 2;
let mut buf = BytesMut::with_capacity(READ_CHUNK);
let mut consecutive_empty = 0u32;
loop {
@@ -1254,11 +1315,28 @@ async fn tunnel_loop(
(true, _) => Duration::from_secs(30),
};
match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
buf.reserve(READ_CHUNK);
match tokio::time::timeout(read_timeout, reader.read_buf(&mut buf)).await {
Ok(Ok(0)) => break,
Ok(Ok(n)) => {
consecutive_empty = 0;
Some(buf[..n].to_vec())
if n >= ZERO_COPY_THRESHOLD {
// Big read: split off the filled region. The
// frozen `Bytes` is at-least-half-full, so the
// saved 64 KB memcpy outweighs the brief
// retention until the mux drains.
Some(buf.split().freeze())
} else {
// Small read: copy out a payload-sized `Bytes`
// and `clear()` so the buffer is reused on the
// next iter (no `reserve` allocation needed
// because the alloc stays uniquely owned).
// Bounds retention to actual data even when
// the mux is backpressured.
let owned = Bytes::copy_from_slice(&buf[..n]);
buf.clear();
Some(owned)
}
}
Ok(Err(_)) => break,
Err(_) => None,
@@ -1275,7 +1353,7 @@ async fn tunnel_loop(
continue;
}
let data = client_data.unwrap_or_default();
let data = client_data.unwrap_or_else(Bytes::new);
let was_empty_poll = data.is_empty();
let (reply_tx, reply_rx) = oneshot::channel();
@@ -1664,7 +1742,7 @@ mod tests {
let mut server_side = accept.await.unwrap();
let (mux, mut rx) = mux_for_test();
let pending = Some(b"CLIENTHELLO".to_vec());
let pending = Some(Bytes::from_static(b"CLIENTHELLO"));
let loop_handle = tokio::spawn({
let mux = mux.clone();
@@ -1907,6 +1985,199 @@ mod tests {
);
}
#[test]
fn should_fire_first_op_never_fires() {
    // With nothing queued there is no batch to flush, so the predicate
    // must say "don't fire" regardless of how large the incoming op is;
    // the op simply becomes the only entry of the next batch.
    let oversized = MAX_BATCH_PAYLOAD_BYTES + 1_000_000;
    for op_bytes in [0, oversized] {
        assert!(!should_fire(0, 0, op_bytes));
    }
}
#[test]
fn should_fire_at_max_ops_threshold() {
    let op_bytes = 100;
    // One slot left below the op cap: the incoming op still fits.
    assert!(!should_fire(MAX_BATCH_OPS - 1, 0, op_bytes));
    // Batch already holds MAX_BATCH_OPS ops (the check is `>=`): flush
    // before adding another.
    assert!(should_fire(MAX_BATCH_OPS, 0, op_bytes));
    // Far beyond the cap: flush as well.
    assert!(should_fire(MAX_BATCH_OPS + 5, 0, op_bytes));
}
#[test]
fn should_fire_when_payload_would_exceed_cap() {
    // Sum lands exactly on the cap: still fits (comparison is strict `>`).
    assert!(!should_fire(10, MAX_BATCH_PAYLOAD_BYTES - 100, 100));
    // One byte over the cap: fire.
    assert!(should_fire(10, MAX_BATCH_PAYLOAD_BYTES - 100, 101));
    // Batch already sitting at the cap, any additional byte: fire.
    assert!(should_fire(10, MAX_BATCH_PAYLOAD_BYTES, 1));
    // Sum would overflow `usize`: `saturating_add` must clamp to "fire"
    // rather than wrap around to a small value that passes the cap.
    assert!(should_fire(10, usize::MAX, 1));
    assert!(should_fire(10, 1, usize::MAX));
}
/// Each reply index has to name the slot its op occupies inside *its
/// own* batch: ops queued before a flush are numbered 0..N-1 in batch
/// A, ops queued afterwards start again at 0 in batch B. A regression
/// here would make `fire_batch`'s `batch_resp.r.get(idx)` lookup
/// deliver the wrong response (or `None`) to a session — silent
/// corruption the encode-layer tests cannot detect.
#[tokio::test]
async fn batch_accum_reindexes_after_flush() {
    // Mirror of the push half of `push_or_fire`, minus the `fire_batch`
    // call, so a flush can be simulated with `mem::take` without
    // mocking the tunnel_request stack.
    fn push_no_fire(
        accum: &mut BatchAccum,
        op: PendingOp,
        op_bytes: usize,
        reply: BatchedReply,
    ) {
        let slot = accum.pending_ops.len();
        accum.pending_ops.push(op);
        accum.data_replies.push((slot, reply));
        accum.payload_bytes += op_bytes;
    }

    // Reply indices currently recorded in the accumulator, in order.
    fn reply_indices(accum: &BatchAccum) -> Vec<usize> {
        accum.data_replies.iter().map(|(idx, _)| *idx).collect()
    }

    let mk_op = |sid: &str| PendingOp {
        op: "data",
        sid: Some(sid.into()),
        host: None,
        port: None,
        data: Some(Bytes::from_static(b"x")),
        encode_empty: false,
    };
    let mk_reply = || oneshot::channel::<Result<(TunnelResponse, String), String>>().0;

    let mut accum = BatchAccum::new();

    // Batch A: three ops occupying slots 0, 1 and 2.
    for sid in ["a0", "a1", "a2"] {
        push_no_fire(&mut accum, mk_op(sid), 4, mk_reply());
    }
    assert_eq!(accum.pending_ops.len(), 3);
    assert_eq!(reply_indices(&accum), vec![0, 1, 2]);
    assert_eq!(accum.payload_bytes, 12);

    // Simulated flush: move the queued state out and zero the byte
    // counter, the same steps `push_or_fire` performs around
    // `fire_batch`. The taken vectors stay bound until the test ends.
    let _flushed_ops = std::mem::take(&mut accum.pending_ops);
    let _flushed_replies = std::mem::take(&mut accum.data_replies);
    accum.payload_bytes = 0;

    // Batch B: two ops, numbering starts over at 0.
    for sid in ["b0", "b1"] {
        push_no_fire(&mut accum, mk_op(sid), 4, mk_reply());
    }
    assert_eq!(accum.pending_ops.len(), 2);
    assert_eq!(
        reply_indices(&accum),
        vec![0, 1],
        "post-flush indices must restart at 0 — otherwise fire_batch's \
         batch_resp.r.get(idx) returns None and every session in the \
         second batch sees a missing-response error"
    );
    assert_eq!(accum.payload_bytes, 8);
}
#[test]
fn encode_pending_data_op_with_payload_emits_base64() {
    // A `data` op with a non-empty payload keeps its op name and sid
    // and carries the payload base64-encoded in `d`.
    let payload: &'static [u8] = b"hello";
    let encoded = encode_pending(PendingOp {
        op: "data",
        sid: Some("sid-1".into()),
        host: None,
        port: None,
        data: Some(Bytes::from_static(payload)),
        encode_empty: false,
    });
    let expected_d = B64.encode(payload);

    assert_eq!(encoded.op, "data");
    assert_eq!(encoded.sid.as_deref(), Some("sid-1"));
    assert_eq!(encoded.d.as_deref(), Some(expected_d.as_str()));
}
#[test]
fn encode_pending_omits_d_for_empty_polls_and_close() {
    // Three payload-free shapes that must all serialize without `d`:
    //  - an empty-poll `data` op (mux_loop maps empty Bytes to `None`),
    //  - a `udp_data` poll with no payload,
    //  - a `close`, which has neither data nor a reply.
    let payload_free = [("data", "sid-2"), ("udp_data", "sid-3"), ("close", "sid-4")];

    for (op, sid) in payload_free {
        let pending = PendingOp {
            op,
            sid: Some(sid.into()),
            host: None,
            port: None,
            data: None,
            encode_empty: false,
        };
        assert!(encode_pending(pending).d.is_none());
    }
}
#[test]
fn encode_pending_connect_data_emits_empty_string_when_data_is_empty() {
    // Defensive case: `connect_data` signals the bundled-first-bytes
    // flow through the mere presence of `d`, so an empty payload that
    // reaches the encoder has to come out as `d: ""`, never as an
    // omitted field.
    let encoded = encode_pending(PendingOp {
        op: "connect_data",
        sid: None,
        host: Some("example.com".into()),
        port: Some(443),
        data: Some(Bytes::new()),
        encode_empty: true,
    });

    assert_eq!(encoded.op, "connect_data");
    assert_eq!(encoded.d.as_deref(), Some(""));
}
#[test]
fn encode_pending_connect_data_with_payload_encodes_normally() {
    // First bytes of a TLS ClientHello record.
    let client_hello_prefix: &'static [u8] = b"\x16\x03\x01";
    let encoded = encode_pending(PendingOp {
        op: "connect_data",
        sid: None,
        host: Some("example.com".into()),
        port: Some(443),
        data: Some(Bytes::from_static(client_hello_prefix)),
        encode_empty: true,
    });
    let expected_d = B64.encode(client_hello_prefix);

    assert_eq!(encoded.d.as_deref(), Some(expected_d.as_str()));
}
#[test]
fn preread_counters_track_each_outcome() {
let (mux, _rx) = mux_for_test();