feat(tunnel): save one RTT per new HTTPS flow via connect_data op

This commit is contained in:
dazzling-no-more
2026-04-25 01:38:30 +04:00
parent 5bb26a4961
commit 0a58943433
4 changed files with 908 additions and 66 deletions
+10 -1
View File
@@ -56,6 +56,12 @@ function _doTunnel(req) {
payload.host = req.h; payload.host = req.h;
payload.port = req.p; payload.port = req.p;
break; break;
case "connect_data":
payload.op = "connect_data";
payload.host = req.h;
payload.port = req.p;
if (req.d) payload.data = req.d;
break;
case "data": case "data":
payload.op = "data"; payload.op = "data";
payload.sid = req.sid; payload.sid = req.sid;
@@ -66,7 +72,10 @@ function _doTunnel(req) {
payload.sid = req.sid; payload.sid = req.sid;
break; break;
default: default:
return _json({ e: "unknown tunnel op: " + req.t }); // Structured `code` lets the Rust client detect version skew
// without substring-matching the error text. Must match
// CODE_UNSUPPORTED_OP in tunnel_client.rs and tunnel-node/src/main.rs.
return _json({ e: "unknown tunnel op: " + req.t, code: "UNSUPPORTED_OP" });
} }
var resp = UrlFetchApp.fetch(TUNNEL_SERVER_URL + "/tunnel", { var resp = UrlFetchApp.fetch(TUNNEL_SERVER_URL + "/tunnel", {
+5
View File
@@ -180,6 +180,11 @@ pub struct TunnelResponse {
pub eof: Option<bool>, pub eof: Option<bool>,
#[serde(default)] #[serde(default)]
pub e: Option<String>, pub e: Option<String>,
/// Structured error code from the tunnel-node (e.g. `UNSUPPORTED_OP`).
/// `None` for legacy tunnel-nodes; clients should fall back to parsing
/// `e` only when this is `None` and compatibility is needed.
#[serde(default)]
pub code: Option<String>,
} }
/// A single op in a batch tunnel request. /// A single op in a batch tunnel request.
+594 -47
View File
@@ -6,12 +6,13 @@
//! 30 in-flight requests — matching the per-account Apps Script limit. //! 30 in-flight requests — matching the per-account Apps Script limit.
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::{Duration, Instant};
use base64::engine::general_purpose::STANDARD as B64; use base64::engine::general_purpose::STANDARD as B64;
use base64::Engine; use base64::Engine;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot, Semaphore}; use tokio::sync::{mpsc, oneshot, Semaphore};
@@ -41,6 +42,25 @@ const BATCH_TIMEOUT: Duration = Duration::from_secs(30);
/// blocking indefinitely. /// blocking indefinitely.
const REPLY_TIMEOUT: Duration = Duration::from_secs(35); const REPLY_TIMEOUT: Duration = Duration::from_secs(35);
/// How long we'll briefly hold the client socket after the local
/// CONNECT/SOCKS5 handshake, waiting for the client's first bytes (the
/// TLS ClientHello for HTTPS). Bundling those bytes with the tunnel-node
/// connect saves one Apps Script round-trip per new flow.
const CLIENT_FIRST_DATA_WAIT: Duration = Duration::from_millis(50);
/// Structured error code the tunnel-node returns when it doesn't know the
/// op (version mismatch). Must match `tunnel-node/src/main.rs`.
const CODE_UNSUPPORTED_OP: &str = "UNSUPPORTED_OP";
/// Ports where the *server* speaks first (SMTP banner, SSH identification,
/// POP3/IMAP greeting, FTP banner). On these, waiting for client bytes
/// gains nothing and just adds handshake latency — skip the pre-read.
/// HTTP on 80 also qualifies because a naive HTTP client may not flush
/// the request line immediately after the CONNECT reply.
fn is_server_speaks_first(port: u16) -> bool {
matches!(port, 21 | 22 | 25 | 80 | 110 | 143 | 587)
}
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Multiplexer // Multiplexer
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -51,6 +71,14 @@ enum MuxMsg {
port: u16, port: u16,
reply: oneshot::Sender<Result<TunnelResponse, String>>, reply: oneshot::Sender<Result<TunnelResponse, String>>,
}, },
ConnectData {
host: String,
port: u16,
// Arc so the caller can hand the buffer to the mux AND keep a ref
// for the fallback path without an extra 64 KB copy per session.
data: Arc<Vec<u8>>,
reply: oneshot::Sender<Result<TunnelResponse, String>>,
},
Data { Data {
sid: String, sid: String,
data: Vec<u8>, data: Vec<u8>,
@@ -63,6 +91,27 @@ enum MuxMsg {
pub struct TunnelMux { pub struct TunnelMux {
tx: mpsc::Sender<MuxMsg>, tx: mpsc::Sender<MuxMsg>,
/// Set to `true` after the first time the tunnel-node rejects
/// `connect_data` as unsupported. Subsequent sessions skip the
/// optimistic path entirely and go straight to plain connect + data.
connect_data_unsupported: Arc<AtomicBool>,
/// Pre-read observability. Lets an operator see whether the 50 ms
/// wait-for-first-bytes is pulling its weight:
/// * `preread_win` — client sent bytes in time, bundled with connect
/// * `preread_loss` — timed out empty; paid 50 ms for nothing
/// * `preread_skip_port` — port was server-speaks-first; skipped wait
/// * `preread_skip_unsupported` — tunnel-node said no; skipped wait
/// A rolling sum of win-time (µs) drives a `mean_win_time` readout so
/// you can tune `CLIENT_FIRST_DATA_WAIT` against real client flush
/// timing. A summary line is logged every 100 preread events.
preread_win: AtomicU64,
preread_loss: AtomicU64,
preread_skip_port: AtomicU64,
preread_skip_unsupported: AtomicU64,
preread_win_total_us: AtomicU64,
/// Separate monotonic counter used only to trigger the summary log
/// (avoids a race where two threads both see `total % 100 == 0`).
preread_total_events: AtomicU64,
} }
impl TunnelMux { impl TunnelMux {
@@ -75,18 +124,88 @@ impl TunnelMux {
); );
let (tx, rx) = mpsc::channel(512); let (tx, rx) = mpsc::channel(512);
tokio::spawn(mux_loop(rx, fronter)); tokio::spawn(mux_loop(rx, fronter));
Arc::new(Self { tx }) Arc::new(Self {
tx,
connect_data_unsupported: Arc::new(AtomicBool::new(false)),
preread_win: AtomicU64::new(0),
preread_loss: AtomicU64::new(0),
preread_skip_port: AtomicU64::new(0),
preread_skip_unsupported: AtomicU64::new(0),
preread_win_total_us: AtomicU64::new(0),
preread_total_events: AtomicU64::new(0),
})
} }
async fn send(&self, msg: MuxMsg) { async fn send(&self, msg: MuxMsg) {
let _ = self.tx.send(msg).await; let _ = self.tx.send(msg).await;
} }
fn connect_data_unsupported(&self) -> bool {
self.connect_data_unsupported.load(Ordering::Relaxed)
}
fn mark_connect_data_unsupported(&self) {
if !self.connect_data_unsupported.swap(true, Ordering::Relaxed) {
tracing::warn!(
"tunnel-node doesn't support connect_data (pre-v1.x); falling back to plain connect + data for all future sessions"
);
}
}
fn record_preread_win(&self, port: u16, elapsed: Duration) {
self.preread_win.fetch_add(1, Ordering::Relaxed);
self.preread_win_total_us
.fetch_add(elapsed.as_micros() as u64, Ordering::Relaxed);
tracing::debug!("preread win: port={} took={:?}", port, elapsed);
self.maybe_log_preread_summary();
}
fn record_preread_loss(&self, port: u16) {
self.preread_loss.fetch_add(1, Ordering::Relaxed);
tracing::debug!("preread loss: port={} (empty within {:?})", port, CLIENT_FIRST_DATA_WAIT);
self.maybe_log_preread_summary();
}
fn record_preread_skip_port(&self, port: u16) {
self.preread_skip_port.fetch_add(1, Ordering::Relaxed);
tracing::debug!("preread skip: port={} (server-speaks-first)", port);
self.maybe_log_preread_summary();
}
fn record_preread_skip_unsupported(&self, port: u16) {
self.preread_skip_unsupported.fetch_add(1, Ordering::Relaxed);
tracing::debug!("preread skip: port={} (connect_data unsupported)", port);
self.maybe_log_preread_summary();
}
/// Emit an aggregate summary exactly once per 100 preread events.
/// Using a dedicated counter for the trigger avoids a race where two
/// threads both observe the win/loss/skip totals summing to a
/// multiple of 100 — here, exactly one thread gets the boundary.
fn maybe_log_preread_summary(&self) {
let new_count = self.preread_total_events.fetch_add(1, Ordering::Relaxed) + 1;
if new_count % 100 != 0 {
return;
}
let win = self.preread_win.load(Ordering::Relaxed);
let loss = self.preread_loss.load(Ordering::Relaxed);
let skip_port = self.preread_skip_port.load(Ordering::Relaxed);
let skip_unsup = self.preread_skip_unsupported.load(Ordering::Relaxed);
let total_us = self.preread_win_total_us.load(Ordering::Relaxed);
let mean_us = if win > 0 { total_us / win } else { 0 };
tracing::info!(
"connect_data preread: {} win / {} loss / {} skip(port) / {} skip(unsup), mean win time {}µs (ceiling {}µs)",
win,
loss,
skip_port,
skip_unsup,
mean_us,
CLIENT_FIRST_DATA_WAIT.as_micros(),
);
}
} }
async fn mux_loop( async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>) {
mut rx: mpsc::Receiver<MuxMsg>,
fronter: Arc<DomainFronter>,
) {
// One semaphore per deployment ID, each allowing 30 concurrent requests. // One semaphore per deployment ID, each allowing 30 concurrent requests.
let sems: Arc<HashMap<String, Arc<Semaphore>>> = Arc::new( let sems: Arc<HashMap<String, Arc<Semaphore>>> = Arc::new(
fronter fronter
@@ -107,7 +226,7 @@ async fn mux_loop(
msgs.push(msg); msgs.push(msg);
} }
// Split: connects go parallel, data/close get batched. // Split: plain connects go parallel, data-bearing ops get batched.
let mut data_ops: Vec<BatchOp> = Vec::new(); let mut data_ops: Vec<BatchOp> = Vec::new();
let mut data_replies: Vec<(usize, oneshot::Sender<Result<TunnelResponse, String>>)> = let mut data_replies: Vec<(usize, oneshot::Sender<Result<TunnelResponse, String>>)> =
Vec::new(); Vec::new();
@@ -128,6 +247,35 @@ async fn mux_loop(
} }
}); });
} }
MuxMsg::ConnectData { host, port, data, reply } => {
let encoded = Some(B64.encode(data.as_slice()));
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
if !data_ops.is_empty()
&& (data_ops.len() >= MAX_BATCH_OPS
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
{
fire_batch(
&sems,
&fronter,
std::mem::take(&mut data_ops),
std::mem::take(&mut data_replies),
)
.await;
batch_payload_bytes = 0;
}
let idx = data_ops.len();
data_ops.push(BatchOp {
op: "connect_data".into(),
sid: None,
host: Some(host),
port: Some(port),
d: encoded,
});
data_replies.push((idx, reply));
batch_payload_bytes += op_bytes;
}
MuxMsg::Data { sid, data, reply } => { MuxMsg::Data { sid, data, reply } => {
let encoded = if data.is_empty() { let encoded = if data.is_empty() {
None None
@@ -219,7 +367,12 @@ async fn fire_batch(
f.tunnel_batch_request_to(&script_id, &data_ops), f.tunnel_batch_request_to(&script_id, &data_ops),
) )
.await; .await;
tracing::info!("batch: {} ops → {}, rtt={:?}", n_ops, &script_id[..script_id.len().min(8)], t0.elapsed()); tracing::info!(
"batch: {} ops → {}, rtt={:?}",
n_ops,
&script_id[..script_id.len().min(8)],
t0.elapsed()
);
match result { match result {
Ok(Ok(batch_resp)) => { Ok(Ok(batch_resp)) => {
@@ -258,6 +411,95 @@ pub async fn tunnel_connection(
port: u16, port: u16,
mux: &Arc<TunnelMux>, mux: &Arc<TunnelMux>,
) -> std::io::Result<()> { ) -> std::io::Result<()> {
// Only try the bundled connect+data optimization when it's likely to
// pay off — client-speaks-first protocols (TLS on 443 et al.) — and
// only if the tunnel-node has already accepted `connect_data` at least
// once this process lifetime (or we haven't tried yet). Check the
// fallback cache first so `skip(unsup)` shadows `skip(port)` in the
// metrics once the feature is disabled process-wide.
let initial_data = if mux.connect_data_unsupported() {
mux.record_preread_skip_unsupported(port);
None
} else if is_server_speaks_first(port) {
mux.record_preread_skip_port(port);
None
} else {
let mut buf = vec![0u8; 65536];
let t0 = Instant::now();
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read(&mut buf)).await {
Ok(Ok(0)) => return Ok(()),
Ok(Ok(n)) => {
mux.record_preread_win(port, t0.elapsed());
buf.truncate(n);
Some(Arc::new(buf))
}
Ok(Err(e)) => return Err(e),
Err(_) => {
mux.record_preread_loss(port);
None
}
}
};
let (sid, first_resp, pending_client_data) = match initial_data {
Some(data) => match connect_with_initial_data(host, port, data.clone(), mux).await? {
ConnectDataOutcome::Opened { sid, response } => (sid, Some(response), None),
ConnectDataOutcome::Unsupported => {
mux.mark_connect_data_unsupported();
let sid = connect_plain(host, port, mux).await?;
// Recover the buffered ClientHello from the Arc so the
// first tunnel_loop iteration can replay it. The mux task
// may still hold the other ref during the unsupported
// reply's settle window — fall back to a clone in that
// race (rare; the reply path drops its ref before we
// reach here in practice).
let bytes = Arc::try_unwrap(data).unwrap_or_else(|a| (*a).clone());
(sid, None, Some(bytes))
}
},
None => (connect_plain(host, port, mux).await?, None, None),
};
tracing::info!("tunnel session {} opened for {}:{}", sid, host, port);
// Run the first-response write + tunnel_loop inside an async block so
// any io-error propagates via `?` without bypassing the Close below.
// We deliberately don't use a Drop guard for Close: a Drop impl can't
// .await cleanly, and tokio::spawn from inside Drop is unreliable
// during runtime shutdown. The explicit send below covers every
// non-panic path; a panic during tunnel_loop would leak the session
// on the tunnel-node until its 5-minute idle reaper runs.
let result = async {
if let Some(resp) = first_resp {
match write_tunnel_response(&mut sock, &resp).await? {
WriteOutcome::Wrote | WriteOutcome::NoData => {}
WriteOutcome::BadBase64 => {
tracing::error!("tunnel session {}: bad base64 in connect_data response", sid);
return Ok(());
}
}
if resp.eof.unwrap_or(false) {
return Ok(());
}
}
tunnel_loop(&mut sock, &sid, mux, pending_client_data).await
}
.await;
mux.send(MuxMsg::Close { sid: sid.clone() }).await;
tracing::info!("tunnel session {} closed for {}:{}", sid, host, port);
result
}
enum ConnectDataOutcome {
Opened {
sid: String,
response: TunnelResponse,
},
Unsupported,
}
async fn connect_plain(host: &str, port: u16, mux: &Arc<TunnelMux>) -> std::io::Result<String> {
let (reply_tx, reply_rx) = oneshot::channel(); let (reply_tx, reply_rx) = oneshot::channel();
mux.send(MuxMsg::Connect { mux.send(MuxMsg::Connect {
host: host.to_string(), host: host.to_string(),
@@ -266,7 +508,7 @@ pub async fn tunnel_connection(
}) })
.await; .await;
let sid = match reply_rx.await { match reply_rx.await {
Ok(Ok(resp)) => { Ok(Ok(resp)) => {
if let Some(ref e) = resp.e { if let Some(ref e) = resp.e {
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e); tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
@@ -277,10 +519,45 @@ pub async fn tunnel_connection(
} }
resp.sid.ok_or_else(|| { resp.sid.ok_or_else(|| {
std::io::Error::new(std::io::ErrorKind::Other, "tunnel connect: no session id") std::io::Error::new(std::io::ErrorKind::Other, "tunnel connect: no session id")
})? })
} }
Ok(Err(e)) => { Ok(Err(e)) => {
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e); tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
Err(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused,
e,
))
}
Err(_) => Err(std::io::Error::new(
std::io::ErrorKind::Other,
"mux channel closed",
)),
}
}
async fn connect_with_initial_data(
host: &str,
port: u16,
data: Arc<Vec<u8>>,
mux: &Arc<TunnelMux>,
) -> std::io::Result<ConnectDataOutcome> {
let (reply_tx, reply_rx) = oneshot::channel();
mux.send(MuxMsg::ConnectData {
host: host.to_string(),
port,
data,
reply: reply_tx,
})
.await;
let resp = match reply_rx.await {
Ok(Ok(resp)) => resp,
Ok(Err(e)) => {
if is_connect_data_unsupported_error_str(&e) {
tracing::debug!("connect_data unsupported for {}:{}: {}", host, port, e);
return Ok(ConnectDataOutcome::Unsupported);
}
tracing::error!("tunnel connect_data error for {}:{}: {}", host, port, e);
return Err(std::io::Error::new( return Err(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused, std::io::ErrorKind::ConnectionRefused,
e, e,
@@ -294,38 +571,96 @@ pub async fn tunnel_connection(
} }
}; };
tracing::info!("tunnel session {} opened for {}:{}", sid, host, port); if is_connect_data_unsupported_response(&resp) {
let result = tunnel_loop(&mut sock, &sid, mux).await; tracing::debug!(
mux.send(MuxMsg::Close { sid: sid.clone() }).await; "connect_data unsupported for {}:{}: {:?}",
tracing::info!("tunnel session {} closed for {}:{}", sid, host, port); host,
result port,
resp.e
);
return Ok(ConnectDataOutcome::Unsupported);
}
if let Some(ref e) = resp.e {
tracing::error!("tunnel connect_data error for {}:{}: {}", host, port, e);
return Err(std::io::Error::new(
std::io::ErrorKind::ConnectionRefused,
e.clone(),
));
}
let Some(sid) = resp.sid.clone() else {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"tunnel connect_data: no session id",
));
};
Ok(ConnectDataOutcome::Opened { sid, response: resp })
}
/// Decide whether a response indicates the tunnel-node (or apps_script
/// layer in front of it) didn't recognize `connect_data`.
///
/// Primary signal: the structured `code` field (`UNSUPPORTED_OP`), emitted
/// by any tunnel-node or apps_script deployment that has this change.
/// Fallback signal (for legacy deployments, pre-connect_data): substring
/// match on the stable error string. The string-match is a one-way
/// compatibility hatch — newer deployments set `code` so future refactors
/// of the error text won't silently break detection.
///
/// Two error shapes are possible on the legacy path:
/// * tunnel-node's single-op/batch handler: `"unknown op: connect_data"`
/// * apps_script's `_doTunnel` default branch: `"unknown tunnel op: connect_data"`
///
/// Apps_script and tunnel-node ship on independent cadences, so it is
/// realistic for a user to upgrade one but not the other — detection has
/// to cover both shapes or the feature hard-fails on version skew.
fn is_connect_data_unsupported_response(resp: &TunnelResponse) -> bool {
if resp.code.as_deref() == Some(CODE_UNSUPPORTED_OP) {
return true;
}
resp.e
.as_deref()
.map(is_connect_data_unsupported_error_str)
.unwrap_or(false)
}
fn is_connect_data_unsupported_error_str(e: &str) -> bool {
let e = e.to_ascii_lowercase();
(e.contains("unknown op") || e.contains("unknown tunnel op")) && e.contains("connect_data")
} }
async fn tunnel_loop( async fn tunnel_loop(
sock: &mut TcpStream, sock: &mut TcpStream,
sid: &str, sid: &str,
mux: &Arc<TunnelMux>, mux: &Arc<TunnelMux>,
mut pending_client_data: Option<Vec<u8>>,
) -> std::io::Result<()> { ) -> std::io::Result<()> {
let (mut reader, mut writer) = sock.split(); let (mut reader, mut writer) = sock.split();
let mut buf = vec![0u8; 65536]; let mut buf = vec![0u8; 65536];
let mut consecutive_empty = 0u32; let mut consecutive_empty = 0u32;
loop { loop {
let read_timeout = match consecutive_empty { let client_data = if let Some(data) = pending_client_data.take() {
0 => Duration::from_millis(20), Some(data)
1 => Duration::from_millis(80), } else {
2 => Duration::from_millis(200), let read_timeout = match consecutive_empty {
_ => Duration::from_secs(30), 0 => Duration::from_millis(20),
}; 1 => Duration::from_millis(80),
2 => Duration::from_millis(200),
_ => Duration::from_secs(30),
};
let client_data = match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await { match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
Ok(Ok(0)) => break, Ok(Ok(0)) => break,
Ok(Ok(n)) => { Ok(Ok(n)) => {
consecutive_empty = 0; consecutive_empty = 0;
Some(buf[..n].to_vec()) Some(buf[..n].to_vec())
}
Ok(Err(_)) => break,
Err(_) => None,
} }
Ok(Err(_)) => break,
Err(_) => None,
}; };
if client_data.is_none() && consecutive_empty > 3 { if client_data.is_none() && consecutive_empty > 3 {
@@ -364,25 +699,15 @@ async fn tunnel_loop(
break; break;
} }
let got_data = if let Some(ref d) = resp.d { let got_data = match write_tunnel_response(&mut writer, &resp).await? {
if !d.is_empty() { WriteOutcome::Wrote => true,
match B64.decode(d) { WriteOutcome::NoData => false,
Ok(bytes) if !bytes.is_empty() => { WriteOutcome::BadBase64 => {
writer.write_all(&bytes).await?; // Tunnel-node gave us garbage; tear the session down but
writer.flush().await?; // do NOT propagate as an io error — the caller's Close
true // guard will clean up on the tunnel-node side.
} break;
Err(e) => {
tracing::error!("tunnel bad base64: {}", e);
break;
}
_ => false,
}
} else {
false
} }
} else {
false
}; };
if resp.eof.unwrap_or(false) { if resp.eof.unwrap_or(false) {
@@ -398,3 +723,225 @@ async fn tunnel_loop(
Ok(()) Ok(())
} }
enum WriteOutcome {
Wrote,
NoData,
BadBase64,
}
async fn write_tunnel_response<W>(
writer: &mut W,
resp: &TunnelResponse,
) -> std::io::Result<WriteOutcome>
where
W: AsyncWrite + Unpin,
{
let Some(ref d) = resp.d else {
return Ok(WriteOutcome::NoData);
};
if d.is_empty() {
return Ok(WriteOutcome::NoData);
}
match B64.decode(d) {
Ok(bytes) if !bytes.is_empty() => {
writer.write_all(&bytes).await?;
writer.flush().await?;
Ok(WriteOutcome::Wrote)
}
Ok(_) => Ok(WriteOutcome::NoData),
Err(e) => {
tracing::error!("tunnel bad base64: {}", e);
Ok(WriteOutcome::BadBase64)
}
}
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
fn resp_with(code: Option<&str>, e: Option<&str>) -> TunnelResponse {
TunnelResponse {
sid: None,
d: None,
eof: None,
e: e.map(str::to_string),
code: code.map(str::to_string),
}
}
#[test]
fn unsupported_detection_via_structured_code() {
assert!(is_connect_data_unsupported_response(&resp_with(Some("UNSUPPORTED_OP"), None)));
assert!(is_connect_data_unsupported_response(&resp_with(
Some("UNSUPPORTED_OP"),
Some("unknown op: connect_data"),
)));
}
#[test]
fn unsupported_detection_via_legacy_tunnel_node_string() {
// Pre-change tunnel-node: no code field, bare "unknown op: ...".
assert!(is_connect_data_unsupported_response(&resp_with(
None, Some("unknown op: connect_data"),
)));
assert!(is_connect_data_unsupported_response(&resp_with(
None, Some("Unknown Op: CONNECT_DATA"),
)));
}
#[test]
fn unsupported_detection_via_legacy_apps_script_string() {
// Pre-change apps_script: default branch emits "unknown tunnel op: ...".
// This is the realistic skew case — user upgrades tunnel-node + client
// binary but hasn't redeployed the Apps Script yet.
assert!(is_connect_data_unsupported_response(&resp_with(
None, Some("unknown tunnel op: connect_data"),
)));
}
#[test]
fn unsupported_detection_rejects_unrelated_errors() {
assert!(!is_connect_data_unsupported_response(&resp_with(
None, Some("connect failed: refused"),
)));
assert!(!is_connect_data_unsupported_response(&resp_with(None, Some("bad base64"))));
assert!(!is_connect_data_unsupported_response(&resp_with(None, None)));
// "connect_data" alone (without "unknown op") shouldn't trigger.
assert!(!is_connect_data_unsupported_response(&resp_with(
None, Some("connect_data: bad port"),
)));
}
#[test]
fn server_speaks_first_covers_common_protocols() {
for p in [21u16, 22, 25, 80, 110, 143, 587] {
assert!(is_server_speaks_first(p), "port {} should be server-first", p);
}
for p in [443u16, 8443, 853, 993, 1234] {
assert!(!is_server_speaks_first(p), "port {} should NOT be server-first", p);
}
}
/// Build a TunnelMux whose send channel is exposed to the test rather
/// than wired to a real DomainFronter. Lets tests assert what messages
/// the client would emit without needing network or apps_script.
fn mux_for_test() -> (Arc<TunnelMux>, mpsc::Receiver<MuxMsg>) {
let (tx, rx) = mpsc::channel(16);
let mux = Arc::new(TunnelMux {
tx,
connect_data_unsupported: Arc::new(AtomicBool::new(false)),
preread_win: AtomicU64::new(0),
preread_loss: AtomicU64::new(0),
preread_skip_port: AtomicU64::new(0),
preread_skip_unsupported: AtomicU64::new(0),
preread_win_total_us: AtomicU64::new(0),
preread_total_events: AtomicU64::new(0),
});
(mux, rx)
}
/// The buffered ClientHello from the pre-read window must reach the
/// tunnel-node as the first `Data` op on the fallback path. If this
/// regresses, every TLS handshake stalls until the 30 s read-timeout
/// fires — catastrophic and silent without a test.
#[tokio::test]
async fn tunnel_loop_replays_pending_client_data_before_reading_socket() {
use tokio::net::TcpListener;
// Set up a loopback pair so tunnel_loop has a real TcpStream to
// work with. We never write to its peer, so tunnel_loop's "read
// from client" branch would block indefinitely — meaning any
// `Data` msg it emits must have come from pending_client_data.
let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
let addr = listener.local_addr().unwrap();
let accept = tokio::spawn(async move { listener.accept().await.unwrap().0 });
let _client = TcpStream::connect(addr).await.unwrap();
let mut server_side = accept.await.unwrap();
let (mux, mut rx) = mux_for_test();
let pending = Some(b"CLIENTHELLO".to_vec());
let loop_handle = tokio::spawn({
let mux = mux.clone();
async move {
tunnel_loop(&mut server_side, "sid-under-test", &mux, pending).await
}
});
// The first message tunnel_loop emits must be Data carrying the
// replayed bytes — NOT whatever it would have read from the socket.
let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
.await
.expect("tunnel_loop did not send a message within 2s")
.expect("mux channel closed unexpectedly");
match msg {
MuxMsg::Data { sid, data, reply } => {
assert_eq!(sid, "sid-under-test");
assert_eq!(&data[..], b"CLIENTHELLO");
// Reply with eof so tunnel_loop unwinds cleanly.
let _ = reply.send(Ok(TunnelResponse {
sid: Some("sid-under-test".into()),
d: None,
eof: Some(true),
e: None,
code: None,
}));
}
other => panic!(
"first mux message was not Data (expected replay); got {:?}",
match other {
MuxMsg::Connect { .. } => "Connect",
MuxMsg::ConnectData { .. } => "ConnectData",
MuxMsg::Data { .. } => unreachable!(),
MuxMsg::Close { .. } => "Close",
}
),
}
let _ = tokio::time::timeout(Duration::from_secs(2), loop_handle)
.await
.expect("tunnel_loop did not exit after eof");
}
/// Once `mark_connect_data_unsupported` is called, future sessions
/// must see the flag — no per-session repeat of the detect-and-fallback
/// cost. If this regresses, every new flow pays an extra round trip
/// against a tunnel-node that will never learn the new op.
#[test]
fn unsupported_cache_is_sticky() {
let (mux, _rx) = mux_for_test();
assert!(!mux.connect_data_unsupported());
mux.mark_connect_data_unsupported();
assert!(mux.connect_data_unsupported());
mux.mark_connect_data_unsupported(); // idempotent
assert!(mux.connect_data_unsupported());
}
#[test]
fn preread_counters_track_each_outcome() {
let (mux, _rx) = mux_for_test();
mux.record_preread_win(443, Duration::from_micros(3_500));
mux.record_preread_win(443, Duration::from_micros(1_500));
mux.record_preread_loss(443);
mux.record_preread_skip_port(80);
mux.record_preread_skip_unsupported(443);
assert_eq!(mux.preread_win.load(Ordering::Relaxed), 2);
assert_eq!(mux.preread_loss.load(Ordering::Relaxed), 1);
assert_eq!(mux.preread_skip_port.load(Ordering::Relaxed), 1);
assert_eq!(mux.preread_skip_unsupported.load(Ordering::Relaxed), 1);
// Two wins summing to 5000 µs.
assert_eq!(mux.preread_win_total_us.load(Ordering::Relaxed), 5_000);
// Five record_* calls, so trigger counter is at 5.
assert_eq!(mux.preread_total_events.load(Ordering::Relaxed), 5);
}
}
+299 -18
View File
@@ -26,6 +26,12 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::TcpStream; use tokio::net::TcpStream;
use tokio::sync::Mutex; use tokio::sync::Mutex;
use tokio::task::JoinSet;
/// Structured error code returned when the tunnel-node receives an op it
/// doesn't recognize. Clients use this (rather than string-matching `e`) to
/// detect a version mismatch and gracefully fall back.
const CODE_UNSUPPORTED_OP: &str = "UNSUPPORTED_OP";
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
// Session // Session
@@ -137,17 +143,25 @@ struct TunnelRequest {
#[serde(default)] data: Option<String>, #[serde(default)] data: Option<String>,
} }
#[derive(Serialize, Clone)] #[derive(Serialize, Clone, Debug)]
struct TunnelResponse { struct TunnelResponse {
#[serde(skip_serializing_if = "Option::is_none")] sid: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] sid: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] d: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] d: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] eof: Option<bool>, #[serde(skip_serializing_if = "Option::is_none")] eof: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")] e: Option<String>, #[serde(skip_serializing_if = "Option::is_none")] e: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")] code: Option<String>,
} }
impl TunnelResponse { impl TunnelResponse {
fn error(msg: impl Into<String>) -> Self { fn error(msg: impl Into<String>) -> Self {
Self { sid: None, d: None, eof: None, e: Some(msg.into()) } Self { sid: None, d: None, eof: None, e: Some(msg.into()), code: None }
}
fn unsupported_op(op: &str) -> Self {
Self {
sid: None, d: None, eof: None,
e: Some(format!("unknown op: {}", op)),
code: Some(CODE_UNSUPPORTED_OP.into()),
}
} }
} }
@@ -188,9 +202,12 @@ async fn handle_tunnel(
} }
match req.op.as_str() { match req.op.as_str() {
"connect" => Json(handle_connect(&state, req.host, req.port).await), "connect" => Json(handle_connect(&state, req.host, req.port).await),
"connect_data" => {
Json(handle_connect_data_single(&state, req.host, req.port, req.data).await)
}
"data" => Json(handle_data_single(&state, req.sid, req.data).await), "data" => Json(handle_data_single(&state, req.sid, req.data).await),
"close" => Json(handle_close(&state, req.sid).await), "close" => Json(handle_close(&state, req.sid).await),
other => Json(TunnelResponse::error(format!("unknown op: {}", other))), other => Json(TunnelResponse::unsupported_op(other)),
} }
} }
@@ -238,15 +255,51 @@ async fn handle_batch(
// then do a short sleep to let servers respond, then drain all. // then do a short sleep to let servers respond, then drain all.
// This batches the network round trips on the server side too. // This batches the network round trips on the server side too.
// Phase 1: process connects and writes // Phase 1: process connects and writes.
//
// `connect` and `connect_data` ops each establish a brand-new upstream
// TCP connection which can take up to 10 s (create_session timeout).
// Running them inline head-of-line-blocks every other op in the batch,
// so we dispatch both into a JoinSet and await them concurrently below.
//
// `connect_data` is expected to dominate in practice (new client) but
// we still hit `connect` from older clients or from server-speaks-first
// ports that skip the pre-read — if a slow `connect` landed in the same
// batch as data-bearing ops it could stall everyone.
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len()); let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
let mut data_ops: Vec<(usize, String)> = Vec::new(); // (index, sid) for data ops needing drain let mut data_ops: Vec<(usize, String)> = Vec::new(); // (index, sid) for data ops needing drain
enum NewConn {
Connect(TunnelResponse),
ConnectData(Result<String, TunnelResponse>),
}
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
for (i, op) in req.ops.iter().enumerate() { for (i, op) in req.ops.iter().enumerate() {
match op.op.as_str() { match op.op.as_str() {
"connect" => { "connect" => {
let r = handle_connect(&state, op.host.clone(), op.port).await; let state = state.clone();
results.push((i, r)); let host = op.host.clone();
let port = op.port;
new_conn_jobs.spawn(async move {
(i, NewConn::Connect(handle_connect(&state, host, port).await))
});
}
"connect_data" => {
let state = state.clone();
let host = op.host.clone();
let port = op.port;
let d = op.d.clone();
new_conn_jobs.spawn(async move {
// Drop the returned Arc<SessionInner>: phase 2 below
// holds the sessions-map lock once for the whole batch
// and re-looks up each sid, which is cheap. The Arc
// return is a convenience for the single-op path only.
let r = handle_connect_data_phase1(&state, host, port, d)
.await
.map(|(sid, _inner)| sid);
(i, NewConn::ConnectData(r))
});
} }
"data" => { "data" => {
let sid = match &op.sid { let sid = match &op.sid {
@@ -273,7 +326,9 @@ async fn handle_batch(
data_ops.push((i, sid)); data_ops.push((i, sid));
} else { } else {
drop(sessions); drop(sessions);
results.push((i, TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None })); results.push((i, TunnelResponse {
sid: Some(sid), d: None, eof: Some(true), e: None, code: None,
}));
} }
} }
"close" => { "close" => {
@@ -281,7 +336,21 @@ async fn handle_batch(
results.push((i, r)); results.push((i, r));
} }
other => { other => {
results.push((i, TunnelResponse::error(format!("unknown op: {}", other)))); results.push((i, TunnelResponse::unsupported_op(other)));
}
}
}
// Await all concurrent connect / connect_data jobs. For connect_data,
// successful ones join the data-drain set in phase 2; plain connects
// go straight to results because they have no initial data to drain.
while let Some(join) = new_conn_jobs.join_next().await {
match join {
Ok((i, NewConn::Connect(r))) => results.push((i, r)),
Ok((i, NewConn::ConnectData(Ok(sid)))) => data_ops.push((i, sid)),
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
Err(e) => {
tracing::error!("new-connection task panicked: {}", e);
} }
} }
} }
@@ -304,12 +373,12 @@ async fn handle_batch(
results.push((*i, TunnelResponse { results.push((*i, TunnelResponse {
sid: Some(sid.clone()), sid: Some(sid.clone()),
d: if data.is_empty() { None } else { Some(B64.encode(&data)) }, d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
eof: Some(eof), e: None, eof: Some(eof), e: None, code: None,
})); }));
} }
} else { } else {
results.push((*i, TunnelResponse { results.push((*i, TunnelResponse {
sid: Some(sid.clone()), d: None, eof: Some(true), e: None, sid: Some(sid.clone()), d: None, eof: Some(true), e: None, code: None,
})); }));
} }
} }
@@ -325,11 +394,11 @@ async fn handle_batch(
results.push((*i, TunnelResponse { results.push((*i, TunnelResponse {
sid: Some(sid.clone()), sid: Some(sid.clone()),
d: if data.is_empty() { None } else { Some(B64.encode(&data)) }, d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
eof: Some(eof), e: None, eof: Some(eof), e: None, code: None,
})); }));
} else { } else {
results.push((*i, TunnelResponse { results.push((*i, TunnelResponse {
sid: Some(sid.clone()), d: None, eof: Some(true), e: None, sid: Some(sid.clone()), d: None, eof: Some(true), e: None, code: None,
})); }));
} }
} }
@@ -372,14 +441,25 @@ fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>, String> {
// Shared op handlers // Shared op handlers
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16>) -> TunnelResponse { fn validate_host_port(
host: Option<String>,
port: Option<u16>,
) -> Result<(String, u16), TunnelResponse> {
let host = match host { let host = match host {
Some(h) if !h.is_empty() => h, Some(h) if !h.is_empty() => h,
_ => return TunnelResponse::error("missing host"), _ => return Err(TunnelResponse::error("missing host")),
}; };
let port = match port { let port = match port {
Some(p) if p > 0 => p, Some(p) if p > 0 => p,
_ => return TunnelResponse::error("missing or invalid port"), _ => return Err(TunnelResponse::error("missing or invalid port")),
};
Ok((host, port))
}
async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16>) -> TunnelResponse {
let (host, port) = match validate_host_port(host, port) {
Ok(v) => v,
Err(r) => return r,
}; };
let session = match create_session(&host, port).await { let session = match create_session(&host, port).await {
Ok(s) => s, Ok(s) => s,
@@ -388,7 +468,82 @@ async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16
let sid = uuid::Uuid::new_v4().to_string(); let sid = uuid::Uuid::new_v4().to_string();
tracing::info!("session {} -> {}:{}", sid, host, port); tracing::info!("session {} -> {}:{}", sid, host, port);
state.sessions.lock().await.insert(sid.clone(), session); state.sessions.lock().await.insert(sid.clone(), session);
TunnelResponse { sid: Some(sid), d: None, eof: Some(false), e: None } TunnelResponse { sid: Some(sid), d: None, eof: Some(false), e: None, code: None }
}
/// Open a session and write the client's first bytes in one round trip.
/// Returns the new sid plus an `Arc<SessionInner>` so unary callers
/// (`handle_connect_data_single`) can drain the first response without a
/// second sessions-map lookup. The batch caller drops the Arc — it takes
/// a single lock across all drain-bound sessions in phase 2, which is
/// cheaper than the Arc plumbing would be.
async fn handle_connect_data_phase1(
state: &AppState,
host: Option<String>,
port: Option<u16>,
data: Option<String>,
) -> Result<(String, Arc<SessionInner>), TunnelResponse> {
let (host, port) = validate_host_port(host, port)?;
let session = create_session(&host, port)
.await
.map_err(|e| TunnelResponse::error(format!("connect failed: {}", e)))?;
// Any failure below this point must abort the reader task, otherwise
// the newly-opened upstream TCP connection would leak. Keep the
// abort paths explicit rather than burying them in `.map_err`.
if let Some(ref data_b64) = data {
if !data_b64.is_empty() {
let bytes = match B64.decode(data_b64) {
Ok(b) => b,
Err(e) => {
session.reader_handle.abort();
return Err(TunnelResponse::error(format!("bad base64: {}", e)));
}
};
if !bytes.is_empty() {
let mut w = session.inner.writer.lock().await;
if let Err(e) = w.write_all(&bytes).await {
drop(w);
session.reader_handle.abort();
return Err(TunnelResponse::error(format!("write failed: {}", e)));
}
let _ = w.flush().await;
}
}
}
let inner = session.inner.clone();
let sid = uuid::Uuid::new_v4().to_string();
tracing::info!("session {} -> {}:{} (connect_data)", sid, host, port);
state.sessions.lock().await.insert(sid.clone(), session);
Ok((sid, inner))
}
async fn handle_connect_data_single(
state: &AppState,
host: Option<String>,
port: Option<u16>,
data: Option<String>,
) -> TunnelResponse {
let (sid, inner) = match handle_connect_data_phase1(state, host, port, data).await {
Ok(v) => v,
Err(r) => return r,
};
let (data, eof) = wait_and_drain(&inner, Duration::from_secs(5)).await;
if eof {
if let Some(s) = state.sessions.lock().await.remove(&sid) {
s.reader_handle.abort();
tracing::info!("session {} closed by remote", sid);
}
}
TunnelResponse {
sid: Some(sid),
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
eof: Some(eof),
e: None,
code: None,
}
} }
async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<String>) -> TunnelResponse { async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<String>) -> TunnelResponse {
@@ -428,7 +583,7 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
TunnelResponse { TunnelResponse {
sid: Some(sid), sid: Some(sid),
d: if data.is_empty() { None } else { Some(B64.encode(&data)) }, d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
eof: Some(eof), e: None, eof: Some(eof), e: None, code: None,
} }
} }
@@ -441,7 +596,7 @@ async fn handle_close(state: &AppState, sid: Option<String>) -> TunnelResponse {
s.reader_handle.abort(); s.reader_handle.abort();
tracing::info!("session {} closed by client", sid); tracing::info!("session {} closed by client", sid);
} }
TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None } TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None, code: None }
} }
// --------------------------------------------------------------------------- // ---------------------------------------------------------------------------
@@ -520,3 +675,129 @@ async fn main() {
.await .await
.unwrap(); .unwrap();
} }
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
use tokio::net::TcpListener;
fn fresh_state() -> AppState {
AppState {
sessions: Arc::new(Mutex::new(HashMap::new())),
auth_key: "test-key".into(),
}
}
/// Spin up a one-shot TCP server that echoes everything it reads back
/// with a `"ECHO: "` prefix, then returns the bound port.
async fn start_echo_server() -> u16 {
let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
let port = listener.local_addr().unwrap().port();
tokio::spawn(async move {
if let Ok((mut sock, _)) = listener.accept().await {
let mut buf = [0u8; 1024];
if let Ok(n) = sock.read(&mut buf).await {
let mut out = b"ECHO: ".to_vec();
out.extend_from_slice(&buf[..n]);
let _ = sock.write_all(&out).await;
let _ = sock.flush().await;
}
}
});
port
}
#[tokio::test]
async fn unsupported_op_response_has_structured_code() {
let resp = TunnelResponse::unsupported_op("connect_data");
assert_eq!(resp.code.as_deref(), Some(CODE_UNSUPPORTED_OP));
assert_eq!(resp.e.as_deref(), Some("unknown op: connect_data"));
}
#[tokio::test]
async fn validate_host_port_rejects_empty_and_zero() {
assert!(validate_host_port(None, Some(443)).is_err());
assert!(validate_host_port(Some("".into()), Some(443)).is_err());
assert!(validate_host_port(Some("x".into()), None).is_err());
assert!(validate_host_port(Some("x".into()), Some(0)).is_err());
assert_eq!(
validate_host_port(Some("host".into()), Some(443)).unwrap(),
("host".to_string(), 443),
);
}
#[tokio::test]
async fn connect_data_phase1_writes_initial_data_and_returns_inner() {
let port = start_echo_server().await;
let state = fresh_state();
let (sid, inner) = handle_connect_data_phase1(
&state,
Some("127.0.0.1".into()),
Some(port),
Some(B64.encode(b"hello")),
)
.await
.expect("phase1 should succeed");
// Session was inserted.
assert!(state.sessions.lock().await.contains_key(&sid));
// Echo server sent back "ECHO: hello". Use wait_and_drain on the
// returned Arc — no map re-lookup needed (this is the fix).
let (data, _eof) = wait_and_drain(&inner, Duration::from_secs(2)).await;
assert_eq!(&data[..], b"ECHO: hello");
}
#[tokio::test]
async fn connect_data_single_bundles_connect_and_first_bytes() {
let port = start_echo_server().await;
let state = fresh_state();
let resp = handle_connect_data_single(
&state,
Some("127.0.0.1".into()),
Some(port),
Some(B64.encode(b"world")),
)
.await;
assert!(resp.e.is_none(), "unexpected error: {:?}", resp.e);
assert!(resp.sid.is_some());
let decoded = B64.decode(resp.d.unwrap()).unwrap();
assert_eq!(&decoded[..], b"ECHO: world");
}
#[tokio::test]
async fn connect_data_rejects_missing_host() {
let state = fresh_state();
let resp = handle_connect_data_single(
&state, None, Some(443), Some(B64.encode(b"x")),
).await;
assert!(resp.e.as_deref().unwrap_or("").contains("missing host"));
assert!(state.sessions.lock().await.is_empty());
}
#[tokio::test]
async fn connect_data_rejects_bad_base64_and_does_not_leak_session() {
// Need a live target so we reach the base64-decode step after
// create_session succeeds — otherwise we'd fail earlier.
let port = start_echo_server().await;
let state = fresh_state();
let resp = handle_connect_data_single(
&state,
Some("127.0.0.1".into()),
Some(port),
Some("!!!not base64!!!".into()),
)
.await;
assert!(resp.e.as_deref().unwrap_or("").contains("bad base64"));
// Session should NOT be in the map since phase1 rejected it.
assert!(state.sessions.lock().await.is_empty());
}
}