mirror of
https://github.com/therealaleph/MasterHttpRelayVPN-RUST.git
synced 2026-05-18 23:54:48 +03:00
feat(tunnel): save one RTT per new HTTPS flow via connect_data op
This commit is contained in:
@@ -56,6 +56,12 @@ function _doTunnel(req) {
|
||||
payload.host = req.h;
|
||||
payload.port = req.p;
|
||||
break;
|
||||
case "connect_data":
|
||||
payload.op = "connect_data";
|
||||
payload.host = req.h;
|
||||
payload.port = req.p;
|
||||
if (req.d) payload.data = req.d;
|
||||
break;
|
||||
case "data":
|
||||
payload.op = "data";
|
||||
payload.sid = req.sid;
|
||||
@@ -66,7 +72,10 @@ function _doTunnel(req) {
|
||||
payload.sid = req.sid;
|
||||
break;
|
||||
default:
|
||||
return _json({ e: "unknown tunnel op: " + req.t });
|
||||
// Structured `code` lets the Rust client detect version skew
|
||||
// without substring-matching the error text. Must match
|
||||
// CODE_UNSUPPORTED_OP in tunnel_client.rs and tunnel-node/src/main.rs.
|
||||
return _json({ e: "unknown tunnel op: " + req.t, code: "UNSUPPORTED_OP" });
|
||||
}
|
||||
|
||||
var resp = UrlFetchApp.fetch(TUNNEL_SERVER_URL + "/tunnel", {
|
||||
|
||||
@@ -180,6 +180,11 @@ pub struct TunnelResponse {
|
||||
pub eof: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub e: Option<String>,
|
||||
/// Structured error code from the tunnel-node (e.g. `UNSUPPORTED_OP`).
|
||||
/// `None` for legacy tunnel-nodes; clients should fall back to parsing
|
||||
/// `e` only when this is `None` and compatibility is needed.
|
||||
#[serde(default)]
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
/// A single op in a batch tunnel request.
|
||||
|
||||
+594
-47
@@ -6,12 +6,13 @@
|
||||
//! 30 in-flight requests — matching the per-account Apps Script limit.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use base64::engine::general_purpose::STANDARD as B64;
|
||||
use base64::Engine;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::{mpsc, oneshot, Semaphore};
|
||||
|
||||
@@ -41,6 +42,25 @@ const BATCH_TIMEOUT: Duration = Duration::from_secs(30);
|
||||
/// blocking indefinitely.
|
||||
const REPLY_TIMEOUT: Duration = Duration::from_secs(35);
|
||||
|
||||
/// How long we'll briefly hold the client socket after the local
|
||||
/// CONNECT/SOCKS5 handshake, waiting for the client's first bytes (the
|
||||
/// TLS ClientHello for HTTPS). Bundling those bytes with the tunnel-node
|
||||
/// connect saves one Apps Script round-trip per new flow.
|
||||
const CLIENT_FIRST_DATA_WAIT: Duration = Duration::from_millis(50);
|
||||
|
||||
/// Structured error code the tunnel-node returns when it doesn't know the
|
||||
/// op (version mismatch). Must match `tunnel-node/src/main.rs`.
|
||||
const CODE_UNSUPPORTED_OP: &str = "UNSUPPORTED_OP";
|
||||
|
||||
/// Ports where the *server* speaks first (SMTP banner, SSH identification,
|
||||
/// POP3/IMAP greeting, FTP banner). On these, waiting for client bytes
|
||||
/// gains nothing and just adds handshake latency — skip the pre-read.
|
||||
/// HTTP on 80 also qualifies because a naive HTTP client may not flush
|
||||
/// the request line immediately after the CONNECT reply.
|
||||
fn is_server_speaks_first(port: u16) -> bool {
|
||||
matches!(port, 21 | 22 | 25 | 80 | 110 | 143 | 587)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Multiplexer
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -51,6 +71,14 @@ enum MuxMsg {
|
||||
port: u16,
|
||||
reply: oneshot::Sender<Result<TunnelResponse, String>>,
|
||||
},
|
||||
ConnectData {
|
||||
host: String,
|
||||
port: u16,
|
||||
// Arc so the caller can hand the buffer to the mux AND keep a ref
|
||||
// for the fallback path without an extra 64 KB copy per session.
|
||||
data: Arc<Vec<u8>>,
|
||||
reply: oneshot::Sender<Result<TunnelResponse, String>>,
|
||||
},
|
||||
Data {
|
||||
sid: String,
|
||||
data: Vec<u8>,
|
||||
@@ -63,6 +91,27 @@ enum MuxMsg {
|
||||
|
||||
pub struct TunnelMux {
|
||||
tx: mpsc::Sender<MuxMsg>,
|
||||
/// Set to `true` after the first time the tunnel-node rejects
|
||||
/// `connect_data` as unsupported. Subsequent sessions skip the
|
||||
/// optimistic path entirely and go straight to plain connect + data.
|
||||
connect_data_unsupported: Arc<AtomicBool>,
|
||||
/// Pre-read observability. Lets an operator see whether the 50 ms
|
||||
/// wait-for-first-bytes is pulling its weight:
|
||||
/// * `preread_win` — client sent bytes in time, bundled with connect
|
||||
/// * `preread_loss` — timed out empty; paid 50 ms for nothing
|
||||
/// * `preread_skip_port` — port was server-speaks-first; skipped wait
|
||||
/// * `preread_skip_unsupported` — tunnel-node said no; skipped wait
|
||||
/// A rolling sum of win-time (µs) drives a `mean_win_time` readout so
|
||||
/// you can tune `CLIENT_FIRST_DATA_WAIT` against real client flush
|
||||
/// timing. A summary line is logged every 100 preread events.
|
||||
preread_win: AtomicU64,
|
||||
preread_loss: AtomicU64,
|
||||
preread_skip_port: AtomicU64,
|
||||
preread_skip_unsupported: AtomicU64,
|
||||
preread_win_total_us: AtomicU64,
|
||||
/// Separate monotonic counter used only to trigger the summary log
|
||||
/// (avoids a race where two threads both see `total % 100 == 0`).
|
||||
preread_total_events: AtomicU64,
|
||||
}
|
||||
|
||||
impl TunnelMux {
|
||||
@@ -75,18 +124,88 @@ impl TunnelMux {
|
||||
);
|
||||
let (tx, rx) = mpsc::channel(512);
|
||||
tokio::spawn(mux_loop(rx, fronter));
|
||||
Arc::new(Self { tx })
|
||||
Arc::new(Self {
|
||||
tx,
|
||||
connect_data_unsupported: Arc::new(AtomicBool::new(false)),
|
||||
preread_win: AtomicU64::new(0),
|
||||
preread_loss: AtomicU64::new(0),
|
||||
preread_skip_port: AtomicU64::new(0),
|
||||
preread_skip_unsupported: AtomicU64::new(0),
|
||||
preread_win_total_us: AtomicU64::new(0),
|
||||
preread_total_events: AtomicU64::new(0),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send(&self, msg: MuxMsg) {
|
||||
let _ = self.tx.send(msg).await;
|
||||
}
|
||||
|
||||
fn connect_data_unsupported(&self) -> bool {
|
||||
self.connect_data_unsupported.load(Ordering::Relaxed)
|
||||
}
|
||||
|
||||
fn mark_connect_data_unsupported(&self) {
|
||||
if !self.connect_data_unsupported.swap(true, Ordering::Relaxed) {
|
||||
tracing::warn!(
|
||||
"tunnel-node doesn't support connect_data (pre-v1.x); falling back to plain connect + data for all future sessions"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn record_preread_win(&self, port: u16, elapsed: Duration) {
|
||||
self.preread_win.fetch_add(1, Ordering::Relaxed);
|
||||
self.preread_win_total_us
|
||||
.fetch_add(elapsed.as_micros() as u64, Ordering::Relaxed);
|
||||
tracing::debug!("preread win: port={} took={:?}", port, elapsed);
|
||||
self.maybe_log_preread_summary();
|
||||
}
|
||||
|
||||
fn record_preread_loss(&self, port: u16) {
|
||||
self.preread_loss.fetch_add(1, Ordering::Relaxed);
|
||||
tracing::debug!("preread loss: port={} (empty within {:?})", port, CLIENT_FIRST_DATA_WAIT);
|
||||
self.maybe_log_preread_summary();
|
||||
}
|
||||
|
||||
fn record_preread_skip_port(&self, port: u16) {
|
||||
self.preread_skip_port.fetch_add(1, Ordering::Relaxed);
|
||||
tracing::debug!("preread skip: port={} (server-speaks-first)", port);
|
||||
self.maybe_log_preread_summary();
|
||||
}
|
||||
|
||||
fn record_preread_skip_unsupported(&self, port: u16) {
|
||||
self.preread_skip_unsupported.fetch_add(1, Ordering::Relaxed);
|
||||
tracing::debug!("preread skip: port={} (connect_data unsupported)", port);
|
||||
self.maybe_log_preread_summary();
|
||||
}
|
||||
|
||||
/// Emit an aggregate summary exactly once per 100 preread events.
|
||||
/// Using a dedicated counter for the trigger avoids a race where two
|
||||
/// threads both observe the win/loss/skip totals summing to a
|
||||
/// multiple of 100 — here, exactly one thread gets the boundary.
|
||||
fn maybe_log_preread_summary(&self) {
|
||||
let new_count = self.preread_total_events.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
if new_count % 100 != 0 {
|
||||
return;
|
||||
}
|
||||
let win = self.preread_win.load(Ordering::Relaxed);
|
||||
let loss = self.preread_loss.load(Ordering::Relaxed);
|
||||
let skip_port = self.preread_skip_port.load(Ordering::Relaxed);
|
||||
let skip_unsup = self.preread_skip_unsupported.load(Ordering::Relaxed);
|
||||
let total_us = self.preread_win_total_us.load(Ordering::Relaxed);
|
||||
let mean_us = if win > 0 { total_us / win } else { 0 };
|
||||
tracing::info!(
|
||||
"connect_data preread: {} win / {} loss / {} skip(port) / {} skip(unsup), mean win time {}µs (ceiling {}µs)",
|
||||
win,
|
||||
loss,
|
||||
skip_port,
|
||||
skip_unsup,
|
||||
mean_us,
|
||||
CLIENT_FIRST_DATA_WAIT.as_micros(),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async fn mux_loop(
|
||||
mut rx: mpsc::Receiver<MuxMsg>,
|
||||
fronter: Arc<DomainFronter>,
|
||||
) {
|
||||
async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>) {
|
||||
// One semaphore per deployment ID, each allowing 30 concurrent requests.
|
||||
let sems: Arc<HashMap<String, Arc<Semaphore>>> = Arc::new(
|
||||
fronter
|
||||
@@ -107,7 +226,7 @@ async fn mux_loop(
|
||||
msgs.push(msg);
|
||||
}
|
||||
|
||||
// Split: connects go parallel, data/close get batched.
|
||||
// Split: plain connects go parallel, data-bearing ops get batched.
|
||||
let mut data_ops: Vec<BatchOp> = Vec::new();
|
||||
let mut data_replies: Vec<(usize, oneshot::Sender<Result<TunnelResponse, String>>)> =
|
||||
Vec::new();
|
||||
@@ -128,6 +247,35 @@ async fn mux_loop(
|
||||
}
|
||||
});
|
||||
}
|
||||
MuxMsg::ConnectData { host, port, data, reply } => {
|
||||
let encoded = Some(B64.encode(data.as_slice()));
|
||||
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
|
||||
|
||||
if !data_ops.is_empty()
|
||||
&& (data_ops.len() >= MAX_BATCH_OPS
|
||||
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
|
||||
{
|
||||
fire_batch(
|
||||
&sems,
|
||||
&fronter,
|
||||
std::mem::take(&mut data_ops),
|
||||
std::mem::take(&mut data_replies),
|
||||
)
|
||||
.await;
|
||||
batch_payload_bytes = 0;
|
||||
}
|
||||
|
||||
let idx = data_ops.len();
|
||||
data_ops.push(BatchOp {
|
||||
op: "connect_data".into(),
|
||||
sid: None,
|
||||
host: Some(host),
|
||||
port: Some(port),
|
||||
d: encoded,
|
||||
});
|
||||
data_replies.push((idx, reply));
|
||||
batch_payload_bytes += op_bytes;
|
||||
}
|
||||
MuxMsg::Data { sid, data, reply } => {
|
||||
let encoded = if data.is_empty() {
|
||||
None
|
||||
@@ -219,7 +367,12 @@ async fn fire_batch(
|
||||
f.tunnel_batch_request_to(&script_id, &data_ops),
|
||||
)
|
||||
.await;
|
||||
tracing::info!("batch: {} ops → {}, rtt={:?}", n_ops, &script_id[..script_id.len().min(8)], t0.elapsed());
|
||||
tracing::info!(
|
||||
"batch: {} ops → {}, rtt={:?}",
|
||||
n_ops,
|
||||
&script_id[..script_id.len().min(8)],
|
||||
t0.elapsed()
|
||||
);
|
||||
|
||||
match result {
|
||||
Ok(Ok(batch_resp)) => {
|
||||
@@ -258,6 +411,95 @@ pub async fn tunnel_connection(
|
||||
port: u16,
|
||||
mux: &Arc<TunnelMux>,
|
||||
) -> std::io::Result<()> {
|
||||
// Only try the bundled connect+data optimization when it's likely to
|
||||
// pay off — client-speaks-first protocols (TLS on 443 et al.) — and
|
||||
// only if the tunnel-node has already accepted `connect_data` at least
|
||||
// once this process lifetime (or we haven't tried yet). Check the
|
||||
// fallback cache first so `skip(unsup)` shadows `skip(port)` in the
|
||||
// metrics once the feature is disabled process-wide.
|
||||
let initial_data = if mux.connect_data_unsupported() {
|
||||
mux.record_preread_skip_unsupported(port);
|
||||
None
|
||||
} else if is_server_speaks_first(port) {
|
||||
mux.record_preread_skip_port(port);
|
||||
None
|
||||
} else {
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let t0 = Instant::now();
|
||||
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read(&mut buf)).await {
|
||||
Ok(Ok(0)) => return Ok(()),
|
||||
Ok(Ok(n)) => {
|
||||
mux.record_preread_win(port, t0.elapsed());
|
||||
buf.truncate(n);
|
||||
Some(Arc::new(buf))
|
||||
}
|
||||
Ok(Err(e)) => return Err(e),
|
||||
Err(_) => {
|
||||
mux.record_preread_loss(port);
|
||||
None
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let (sid, first_resp, pending_client_data) = match initial_data {
|
||||
Some(data) => match connect_with_initial_data(host, port, data.clone(), mux).await? {
|
||||
ConnectDataOutcome::Opened { sid, response } => (sid, Some(response), None),
|
||||
ConnectDataOutcome::Unsupported => {
|
||||
mux.mark_connect_data_unsupported();
|
||||
let sid = connect_plain(host, port, mux).await?;
|
||||
// Recover the buffered ClientHello from the Arc so the
|
||||
// first tunnel_loop iteration can replay it. The mux task
|
||||
// may still hold the other ref during the unsupported
|
||||
// reply's settle window — fall back to a clone in that
|
||||
// race (rare; the reply path drops its ref before we
|
||||
// reach here in practice).
|
||||
let bytes = Arc::try_unwrap(data).unwrap_or_else(|a| (*a).clone());
|
||||
(sid, None, Some(bytes))
|
||||
}
|
||||
},
|
||||
None => (connect_plain(host, port, mux).await?, None, None),
|
||||
};
|
||||
|
||||
tracing::info!("tunnel session {} opened for {}:{}", sid, host, port);
|
||||
|
||||
// Run the first-response write + tunnel_loop inside an async block so
|
||||
// any io-error propagates via `?` without bypassing the Close below.
|
||||
// We deliberately don't use a Drop guard for Close: a Drop impl can't
|
||||
// .await cleanly, and tokio::spawn from inside Drop is unreliable
|
||||
// during runtime shutdown. The explicit send below covers every
|
||||
// non-panic path; a panic during tunnel_loop would leak the session
|
||||
// on the tunnel-node until its 5-minute idle reaper runs.
|
||||
let result = async {
|
||||
if let Some(resp) = first_resp {
|
||||
match write_tunnel_response(&mut sock, &resp).await? {
|
||||
WriteOutcome::Wrote | WriteOutcome::NoData => {}
|
||||
WriteOutcome::BadBase64 => {
|
||||
tracing::error!("tunnel session {}: bad base64 in connect_data response", sid);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
if resp.eof.unwrap_or(false) {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
tunnel_loop(&mut sock, &sid, mux, pending_client_data).await
|
||||
}
|
||||
.await;
|
||||
|
||||
mux.send(MuxMsg::Close { sid: sid.clone() }).await;
|
||||
tracing::info!("tunnel session {} closed for {}:{}", sid, host, port);
|
||||
result
|
||||
}
|
||||
|
||||
enum ConnectDataOutcome {
|
||||
Opened {
|
||||
sid: String,
|
||||
response: TunnelResponse,
|
||||
},
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
async fn connect_plain(host: &str, port: u16, mux: &Arc<TunnelMux>) -> std::io::Result<String> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
mux.send(MuxMsg::Connect {
|
||||
host: host.to_string(),
|
||||
@@ -266,7 +508,7 @@ pub async fn tunnel_connection(
|
||||
})
|
||||
.await;
|
||||
|
||||
let sid = match reply_rx.await {
|
||||
match reply_rx.await {
|
||||
Ok(Ok(resp)) => {
|
||||
if let Some(ref e) = resp.e {
|
||||
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
|
||||
@@ -277,10 +519,45 @@ pub async fn tunnel_connection(
|
||||
}
|
||||
resp.sid.ok_or_else(|| {
|
||||
std::io::Error::new(std::io::ErrorKind::Other, "tunnel connect: no session id")
|
||||
})?
|
||||
})
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
|
||||
Err(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionRefused,
|
||||
e,
|
||||
))
|
||||
}
|
||||
Err(_) => Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"mux channel closed",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
async fn connect_with_initial_data(
|
||||
host: &str,
|
||||
port: u16,
|
||||
data: Arc<Vec<u8>>,
|
||||
mux: &Arc<TunnelMux>,
|
||||
) -> std::io::Result<ConnectDataOutcome> {
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
mux.send(MuxMsg::ConnectData {
|
||||
host: host.to_string(),
|
||||
port,
|
||||
data,
|
||||
reply: reply_tx,
|
||||
})
|
||||
.await;
|
||||
|
||||
let resp = match reply_rx.await {
|
||||
Ok(Ok(resp)) => resp,
|
||||
Ok(Err(e)) => {
|
||||
if is_connect_data_unsupported_error_str(&e) {
|
||||
tracing::debug!("connect_data unsupported for {}:{}: {}", host, port, e);
|
||||
return Ok(ConnectDataOutcome::Unsupported);
|
||||
}
|
||||
tracing::error!("tunnel connect_data error for {}:{}: {}", host, port, e);
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionRefused,
|
||||
e,
|
||||
@@ -294,38 +571,96 @@ pub async fn tunnel_connection(
|
||||
}
|
||||
};
|
||||
|
||||
tracing::info!("tunnel session {} opened for {}:{}", sid, host, port);
|
||||
let result = tunnel_loop(&mut sock, &sid, mux).await;
|
||||
mux.send(MuxMsg::Close { sid: sid.clone() }).await;
|
||||
tracing::info!("tunnel session {} closed for {}:{}", sid, host, port);
|
||||
result
|
||||
if is_connect_data_unsupported_response(&resp) {
|
||||
tracing::debug!(
|
||||
"connect_data unsupported for {}:{}: {:?}",
|
||||
host,
|
||||
port,
|
||||
resp.e
|
||||
);
|
||||
return Ok(ConnectDataOutcome::Unsupported);
|
||||
}
|
||||
|
||||
if let Some(ref e) = resp.e {
|
||||
tracing::error!("tunnel connect_data error for {}:{}: {}", host, port, e);
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::ConnectionRefused,
|
||||
e.clone(),
|
||||
));
|
||||
}
|
||||
|
||||
let Some(sid) = resp.sid.clone() else {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::Other,
|
||||
"tunnel connect_data: no session id",
|
||||
));
|
||||
};
|
||||
|
||||
Ok(ConnectDataOutcome::Opened { sid, response: resp })
|
||||
}
|
||||
|
||||
/// Decide whether a response indicates the tunnel-node (or apps_script
|
||||
/// layer in front of it) didn't recognize `connect_data`.
|
||||
///
|
||||
/// Primary signal: the structured `code` field (`UNSUPPORTED_OP`), emitted
|
||||
/// by any tunnel-node or apps_script deployment that has this change.
|
||||
/// Fallback signal (for legacy deployments, pre-connect_data): substring
|
||||
/// match on the stable error string. The string-match is a one-way
|
||||
/// compatibility hatch — newer deployments set `code` so future refactors
|
||||
/// of the error text won't silently break detection.
|
||||
///
|
||||
/// Two error shapes are possible on the legacy path:
|
||||
/// * tunnel-node's single-op/batch handler: `"unknown op: connect_data"`
|
||||
/// * apps_script's `_doTunnel` default branch: `"unknown tunnel op: connect_data"`
|
||||
///
|
||||
/// Apps_script and tunnel-node ship on independent cadences, so it is
|
||||
/// realistic for a user to upgrade one but not the other — detection has
|
||||
/// to cover both shapes or the feature hard-fails on version skew.
|
||||
fn is_connect_data_unsupported_response(resp: &TunnelResponse) -> bool {
|
||||
if resp.code.as_deref() == Some(CODE_UNSUPPORTED_OP) {
|
||||
return true;
|
||||
}
|
||||
resp.e
|
||||
.as_deref()
|
||||
.map(is_connect_data_unsupported_error_str)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn is_connect_data_unsupported_error_str(e: &str) -> bool {
|
||||
let e = e.to_ascii_lowercase();
|
||||
(e.contains("unknown op") || e.contains("unknown tunnel op")) && e.contains("connect_data")
|
||||
}
|
||||
|
||||
async fn tunnel_loop(
|
||||
sock: &mut TcpStream,
|
||||
sid: &str,
|
||||
mux: &Arc<TunnelMux>,
|
||||
mut pending_client_data: Option<Vec<u8>>,
|
||||
) -> std::io::Result<()> {
|
||||
let (mut reader, mut writer) = sock.split();
|
||||
let mut buf = vec![0u8; 65536];
|
||||
let mut consecutive_empty = 0u32;
|
||||
|
||||
loop {
|
||||
let read_timeout = match consecutive_empty {
|
||||
0 => Duration::from_millis(20),
|
||||
1 => Duration::from_millis(80),
|
||||
2 => Duration::from_millis(200),
|
||||
_ => Duration::from_secs(30),
|
||||
};
|
||||
let client_data = if let Some(data) = pending_client_data.take() {
|
||||
Some(data)
|
||||
} else {
|
||||
let read_timeout = match consecutive_empty {
|
||||
0 => Duration::from_millis(20),
|
||||
1 => Duration::from_millis(80),
|
||||
2 => Duration::from_millis(200),
|
||||
_ => Duration::from_secs(30),
|
||||
};
|
||||
|
||||
let client_data = match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => {
|
||||
consecutive_empty = 0;
|
||||
Some(buf[..n].to_vec())
|
||||
match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
|
||||
Ok(Ok(0)) => break,
|
||||
Ok(Ok(n)) => {
|
||||
consecutive_empty = 0;
|
||||
Some(buf[..n].to_vec())
|
||||
}
|
||||
Ok(Err(_)) => break,
|
||||
Err(_) => None,
|
||||
}
|
||||
Ok(Err(_)) => break,
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
if client_data.is_none() && consecutive_empty > 3 {
|
||||
@@ -364,25 +699,15 @@ async fn tunnel_loop(
|
||||
break;
|
||||
}
|
||||
|
||||
let got_data = if let Some(ref d) = resp.d {
|
||||
if !d.is_empty() {
|
||||
match B64.decode(d) {
|
||||
Ok(bytes) if !bytes.is_empty() => {
|
||||
writer.write_all(&bytes).await?;
|
||||
writer.flush().await?;
|
||||
true
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::error!("tunnel bad base64: {}", e);
|
||||
break;
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
let got_data = match write_tunnel_response(&mut writer, &resp).await? {
|
||||
WriteOutcome::Wrote => true,
|
||||
WriteOutcome::NoData => false,
|
||||
WriteOutcome::BadBase64 => {
|
||||
// Tunnel-node gave us garbage; tear the session down but
|
||||
// do NOT propagate as an io error — the caller's Close
|
||||
// guard will clean up on the tunnel-node side.
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
if resp.eof.unwrap_or(false) {
|
||||
@@ -398,3 +723,225 @@ async fn tunnel_loop(
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
enum WriteOutcome {
|
||||
Wrote,
|
||||
NoData,
|
||||
BadBase64,
|
||||
}
|
||||
|
||||
async fn write_tunnel_response<W>(
|
||||
writer: &mut W,
|
||||
resp: &TunnelResponse,
|
||||
) -> std::io::Result<WriteOutcome>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let Some(ref d) = resp.d else {
|
||||
return Ok(WriteOutcome::NoData);
|
||||
};
|
||||
if d.is_empty() {
|
||||
return Ok(WriteOutcome::NoData);
|
||||
}
|
||||
|
||||
match B64.decode(d) {
|
||||
Ok(bytes) if !bytes.is_empty() => {
|
||||
writer.write_all(&bytes).await?;
|
||||
writer.flush().await?;
|
||||
Ok(WriteOutcome::Wrote)
|
||||
}
|
||||
Ok(_) => Ok(WriteOutcome::NoData),
|
||||
Err(e) => {
|
||||
tracing::error!("tunnel bad base64: {}", e);
|
||||
Ok(WriteOutcome::BadBase64)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn resp_with(code: Option<&str>, e: Option<&str>) -> TunnelResponse {
|
||||
TunnelResponse {
|
||||
sid: None,
|
||||
d: None,
|
||||
eof: None,
|
||||
e: e.map(str::to_string),
|
||||
code: code.map(str::to_string),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unsupported_detection_via_structured_code() {
|
||||
assert!(is_connect_data_unsupported_response(&resp_with(Some("UNSUPPORTED_OP"), None)));
|
||||
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||
Some("UNSUPPORTED_OP"),
|
||||
Some("unknown op: connect_data"),
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unsupported_detection_via_legacy_tunnel_node_string() {
|
||||
// Pre-change tunnel-node: no code field, bare "unknown op: ...".
|
||||
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||
None, Some("unknown op: connect_data"),
|
||||
)));
|
||||
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||
None, Some("Unknown Op: CONNECT_DATA"),
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unsupported_detection_via_legacy_apps_script_string() {
|
||||
// Pre-change apps_script: default branch emits "unknown tunnel op: ...".
|
||||
// This is the realistic skew case — user upgrades tunnel-node + client
|
||||
// binary but hasn't redeployed the Apps Script yet.
|
||||
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||
None, Some("unknown tunnel op: connect_data"),
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unsupported_detection_rejects_unrelated_errors() {
|
||||
assert!(!is_connect_data_unsupported_response(&resp_with(
|
||||
None, Some("connect failed: refused"),
|
||||
)));
|
||||
assert!(!is_connect_data_unsupported_response(&resp_with(None, Some("bad base64"))));
|
||||
assert!(!is_connect_data_unsupported_response(&resp_with(None, None)));
|
||||
// "connect_data" alone (without "unknown op") shouldn't trigger.
|
||||
assert!(!is_connect_data_unsupported_response(&resp_with(
|
||||
None, Some("connect_data: bad port"),
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn server_speaks_first_covers_common_protocols() {
|
||||
for p in [21u16, 22, 25, 80, 110, 143, 587] {
|
||||
assert!(is_server_speaks_first(p), "port {} should be server-first", p);
|
||||
}
|
||||
for p in [443u16, 8443, 853, 993, 1234] {
|
||||
assert!(!is_server_speaks_first(p), "port {} should NOT be server-first", p);
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a TunnelMux whose send channel is exposed to the test rather
|
||||
/// than wired to a real DomainFronter. Lets tests assert what messages
|
||||
/// the client would emit without needing network or apps_script.
|
||||
fn mux_for_test() -> (Arc<TunnelMux>, mpsc::Receiver<MuxMsg>) {
|
||||
let (tx, rx) = mpsc::channel(16);
|
||||
let mux = Arc::new(TunnelMux {
|
||||
tx,
|
||||
connect_data_unsupported: Arc::new(AtomicBool::new(false)),
|
||||
preread_win: AtomicU64::new(0),
|
||||
preread_loss: AtomicU64::new(0),
|
||||
preread_skip_port: AtomicU64::new(0),
|
||||
preread_skip_unsupported: AtomicU64::new(0),
|
||||
preread_win_total_us: AtomicU64::new(0),
|
||||
preread_total_events: AtomicU64::new(0),
|
||||
});
|
||||
(mux, rx)
|
||||
}
|
||||
|
||||
/// The buffered ClientHello from the pre-read window must reach the
|
||||
/// tunnel-node as the first `Data` op on the fallback path. If this
|
||||
/// regresses, every TLS handshake stalls until the 30 s read-timeout
|
||||
/// fires — catastrophic and silent without a test.
|
||||
#[tokio::test]
|
||||
async fn tunnel_loop_replays_pending_client_data_before_reading_socket() {
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
// Set up a loopback pair so tunnel_loop has a real TcpStream to
|
||||
// work with. We never write to its peer, so tunnel_loop's "read
|
||||
// from client" branch would block indefinitely — meaning any
|
||||
// `Data` msg it emits must have come from pending_client_data.
|
||||
let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
|
||||
let addr = listener.local_addr().unwrap();
|
||||
let accept = tokio::spawn(async move { listener.accept().await.unwrap().0 });
|
||||
let _client = TcpStream::connect(addr).await.unwrap();
|
||||
let mut server_side = accept.await.unwrap();
|
||||
|
||||
let (mux, mut rx) = mux_for_test();
|
||||
let pending = Some(b"CLIENTHELLO".to_vec());
|
||||
|
||||
let loop_handle = tokio::spawn({
|
||||
let mux = mux.clone();
|
||||
async move {
|
||||
tunnel_loop(&mut server_side, "sid-under-test", &mux, pending).await
|
||||
}
|
||||
});
|
||||
|
||||
// The first message tunnel_loop emits must be Data carrying the
|
||||
// replayed bytes — NOT whatever it would have read from the socket.
|
||||
let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
|
||||
.await
|
||||
.expect("tunnel_loop did not send a message within 2s")
|
||||
.expect("mux channel closed unexpectedly");
|
||||
|
||||
match msg {
|
||||
MuxMsg::Data { sid, data, reply } => {
|
||||
assert_eq!(sid, "sid-under-test");
|
||||
assert_eq!(&data[..], b"CLIENTHELLO");
|
||||
// Reply with eof so tunnel_loop unwinds cleanly.
|
||||
let _ = reply.send(Ok(TunnelResponse {
|
||||
sid: Some("sid-under-test".into()),
|
||||
d: None,
|
||||
eof: Some(true),
|
||||
e: None,
|
||||
code: None,
|
||||
}));
|
||||
}
|
||||
other => panic!(
|
||||
"first mux message was not Data (expected replay); got {:?}",
|
||||
match other {
|
||||
MuxMsg::Connect { .. } => "Connect",
|
||||
MuxMsg::ConnectData { .. } => "ConnectData",
|
||||
MuxMsg::Data { .. } => unreachable!(),
|
||||
MuxMsg::Close { .. } => "Close",
|
||||
}
|
||||
),
|
||||
}
|
||||
|
||||
let _ = tokio::time::timeout(Duration::from_secs(2), loop_handle)
|
||||
.await
|
||||
.expect("tunnel_loop did not exit after eof");
|
||||
}
|
||||
|
||||
/// Once `mark_connect_data_unsupported` is called, future sessions
|
||||
/// must see the flag — no per-session repeat of the detect-and-fallback
|
||||
/// cost. If this regresses, every new flow pays an extra round trip
|
||||
/// against a tunnel-node that will never learn the new op.
|
||||
#[test]
|
||||
fn unsupported_cache_is_sticky() {
|
||||
let (mux, _rx) = mux_for_test();
|
||||
assert!(!mux.connect_data_unsupported());
|
||||
mux.mark_connect_data_unsupported();
|
||||
assert!(mux.connect_data_unsupported());
|
||||
mux.mark_connect_data_unsupported(); // idempotent
|
||||
assert!(mux.connect_data_unsupported());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn preread_counters_track_each_outcome() {
|
||||
let (mux, _rx) = mux_for_test();
|
||||
|
||||
mux.record_preread_win(443, Duration::from_micros(3_500));
|
||||
mux.record_preread_win(443, Duration::from_micros(1_500));
|
||||
mux.record_preread_loss(443);
|
||||
mux.record_preread_skip_port(80);
|
||||
mux.record_preread_skip_unsupported(443);
|
||||
|
||||
assert_eq!(mux.preread_win.load(Ordering::Relaxed), 2);
|
||||
assert_eq!(mux.preread_loss.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(mux.preread_skip_port.load(Ordering::Relaxed), 1);
|
||||
assert_eq!(mux.preread_skip_unsupported.load(Ordering::Relaxed), 1);
|
||||
// Two wins summing to 5000 µs.
|
||||
assert_eq!(mux.preread_win_total_us.load(Ordering::Relaxed), 5_000);
|
||||
// Five record_* calls, so trigger counter is at 5.
|
||||
assert_eq!(mux.preread_total_events.load(Ordering::Relaxed), 5);
|
||||
}
|
||||
}
|
||||
|
||||
+299
-18
@@ -26,6 +26,12 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
/// Structured error code returned when the tunnel-node receives an op it
|
||||
/// doesn't recognize. Clients use this (rather than string-matching `e`) to
|
||||
/// detect a version mismatch and gracefully fall back.
|
||||
const CODE_UNSUPPORTED_OP: &str = "UNSUPPORTED_OP";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Session
|
||||
@@ -137,17 +143,25 @@ struct TunnelRequest {
|
||||
#[serde(default)] data: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
#[derive(Serialize, Clone, Debug)]
|
||||
struct TunnelResponse {
|
||||
#[serde(skip_serializing_if = "Option::is_none")] sid: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] d: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] eof: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] e: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] code: Option<String>,
|
||||
}
|
||||
|
||||
impl TunnelResponse {
|
||||
fn error(msg: impl Into<String>) -> Self {
|
||||
Self { sid: None, d: None, eof: None, e: Some(msg.into()) }
|
||||
Self { sid: None, d: None, eof: None, e: Some(msg.into()), code: None }
|
||||
}
|
||||
fn unsupported_op(op: &str) -> Self {
|
||||
Self {
|
||||
sid: None, d: None, eof: None,
|
||||
e: Some(format!("unknown op: {}", op)),
|
||||
code: Some(CODE_UNSUPPORTED_OP.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -188,9 +202,12 @@ async fn handle_tunnel(
|
||||
}
|
||||
match req.op.as_str() {
|
||||
"connect" => Json(handle_connect(&state, req.host, req.port).await),
|
||||
"connect_data" => {
|
||||
Json(handle_connect_data_single(&state, req.host, req.port, req.data).await)
|
||||
}
|
||||
"data" => Json(handle_data_single(&state, req.sid, req.data).await),
|
||||
"close" => Json(handle_close(&state, req.sid).await),
|
||||
other => Json(TunnelResponse::error(format!("unknown op: {}", other))),
|
||||
other => Json(TunnelResponse::unsupported_op(other)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,15 +255,51 @@ async fn handle_batch(
|
||||
// then do a short sleep to let servers respond, then drain all.
|
||||
// This batches the network round trips on the server side too.
|
||||
|
||||
// Phase 1: process connects and writes
|
||||
// Phase 1: process connects and writes.
|
||||
//
|
||||
// `connect` and `connect_data` ops each establish a brand-new upstream
|
||||
// TCP connection which can take up to 10 s (create_session timeout).
|
||||
// Running them inline head-of-line-blocks every other op in the batch,
|
||||
// so we dispatch both into a JoinSet and await them concurrently below.
|
||||
//
|
||||
// `connect_data` is expected to dominate in practice (new client) but
|
||||
// we still hit `connect` from older clients or from server-speaks-first
|
||||
// ports that skip the pre-read — if a slow `connect` landed in the same
|
||||
// batch as data-bearing ops it could stall everyone.
|
||||
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
|
||||
let mut data_ops: Vec<(usize, String)> = Vec::new(); // (index, sid) for data ops needing drain
|
||||
|
||||
enum NewConn {
|
||||
Connect(TunnelResponse),
|
||||
ConnectData(Result<String, TunnelResponse>),
|
||||
}
|
||||
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
|
||||
|
||||
for (i, op) in req.ops.iter().enumerate() {
|
||||
match op.op.as_str() {
|
||||
"connect" => {
|
||||
let r = handle_connect(&state, op.host.clone(), op.port).await;
|
||||
results.push((i, r));
|
||||
let state = state.clone();
|
||||
let host = op.host.clone();
|
||||
let port = op.port;
|
||||
new_conn_jobs.spawn(async move {
|
||||
(i, NewConn::Connect(handle_connect(&state, host, port).await))
|
||||
});
|
||||
}
|
||||
"connect_data" => {
|
||||
let state = state.clone();
|
||||
let host = op.host.clone();
|
||||
let port = op.port;
|
||||
let d = op.d.clone();
|
||||
new_conn_jobs.spawn(async move {
|
||||
// Drop the returned Arc<SessionInner>: phase 2 below
|
||||
// holds the sessions-map lock once for the whole batch
|
||||
// and re-looks up each sid, which is cheap. The Arc
|
||||
// return is a convenience for the single-op path only.
|
||||
let r = handle_connect_data_phase1(&state, host, port, d)
|
||||
.await
|
||||
.map(|(sid, _inner)| sid);
|
||||
(i, NewConn::ConnectData(r))
|
||||
});
|
||||
}
|
||||
"data" => {
|
||||
let sid = match &op.sid {
|
||||
@@ -273,7 +326,9 @@ async fn handle_batch(
|
||||
data_ops.push((i, sid));
|
||||
} else {
|
||||
drop(sessions);
|
||||
results.push((i, TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None }));
|
||||
results.push((i, TunnelResponse {
|
||||
sid: Some(sid), d: None, eof: Some(true), e: None, code: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
"close" => {
|
||||
@@ -281,7 +336,21 @@ async fn handle_batch(
|
||||
results.push((i, r));
|
||||
}
|
||||
other => {
|
||||
results.push((i, TunnelResponse::error(format!("unknown op: {}", other))));
|
||||
results.push((i, TunnelResponse::unsupported_op(other)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Await all concurrent connect / connect_data jobs. For connect_data,
|
||||
// successful ones join the data-drain set in phase 2; plain connects
|
||||
// go straight to results because they have no initial data to drain.
|
||||
while let Some(join) = new_conn_jobs.join_next().await {
|
||||
match join {
|
||||
Ok((i, NewConn::Connect(r))) => results.push((i, r)),
|
||||
Ok((i, NewConn::ConnectData(Ok(sid)))) => data_ops.push((i, sid)),
|
||||
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
|
||||
Err(e) => {
|
||||
tracing::error!("new-connection task panicked: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -304,12 +373,12 @@ async fn handle_batch(
|
||||
results.push((*i, TunnelResponse {
|
||||
sid: Some(sid.clone()),
|
||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||
eof: Some(eof), e: None,
|
||||
eof: Some(eof), e: None, code: None,
|
||||
}));
|
||||
}
|
||||
} else {
|
||||
results.push((*i, TunnelResponse {
|
||||
sid: Some(sid.clone()), d: None, eof: Some(true), e: None,
|
||||
sid: Some(sid.clone()), d: None, eof: Some(true), e: None, code: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -325,11 +394,11 @@ async fn handle_batch(
|
||||
results.push((*i, TunnelResponse {
|
||||
sid: Some(sid.clone()),
|
||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||
eof: Some(eof), e: None,
|
||||
eof: Some(eof), e: None, code: None,
|
||||
}));
|
||||
} else {
|
||||
results.push((*i, TunnelResponse {
|
||||
sid: Some(sid.clone()), d: None, eof: Some(true), e: None,
|
||||
sid: Some(sid.clone()), d: None, eof: Some(true), e: None, code: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
@@ -372,14 +441,25 @@ fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>, String> {
|
||||
// Shared op handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16>) -> TunnelResponse {
|
||||
fn validate_host_port(
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
) -> Result<(String, u16), TunnelResponse> {
|
||||
let host = match host {
|
||||
Some(h) if !h.is_empty() => h,
|
||||
_ => return TunnelResponse::error("missing host"),
|
||||
_ => return Err(TunnelResponse::error("missing host")),
|
||||
};
|
||||
let port = match port {
|
||||
Some(p) if p > 0 => p,
|
||||
_ => return TunnelResponse::error("missing or invalid port"),
|
||||
_ => return Err(TunnelResponse::error("missing or invalid port")),
|
||||
};
|
||||
Ok((host, port))
|
||||
}
|
||||
|
||||
async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16>) -> TunnelResponse {
|
||||
let (host, port) = match validate_host_port(host, port) {
|
||||
Ok(v) => v,
|
||||
Err(r) => return r,
|
||||
};
|
||||
let session = match create_session(&host, port).await {
|
||||
Ok(s) => s,
|
||||
@@ -388,7 +468,82 @@ async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16
|
||||
let sid = uuid::Uuid::new_v4().to_string();
|
||||
tracing::info!("session {} -> {}:{}", sid, host, port);
|
||||
state.sessions.lock().await.insert(sid.clone(), session);
|
||||
TunnelResponse { sid: Some(sid), d: None, eof: Some(false), e: None }
|
||||
TunnelResponse { sid: Some(sid), d: None, eof: Some(false), e: None, code: None }
|
||||
}
|
||||
|
||||
/// Open a session and write the client's first bytes in one round trip.
|
||||
/// Returns the new sid plus an `Arc<SessionInner>` so unary callers
|
||||
/// (`handle_connect_data_single`) can drain the first response without a
|
||||
/// second sessions-map lookup. The batch caller drops the Arc — it takes
|
||||
/// a single lock across all drain-bound sessions in phase 2, which is
|
||||
/// cheaper than the Arc plumbing would be.
|
||||
async fn handle_connect_data_phase1(
|
||||
state: &AppState,
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
data: Option<String>,
|
||||
) -> Result<(String, Arc<SessionInner>), TunnelResponse> {
|
||||
let (host, port) = validate_host_port(host, port)?;
|
||||
|
||||
let session = create_session(&host, port)
|
||||
.await
|
||||
.map_err(|e| TunnelResponse::error(format!("connect failed: {}", e)))?;
|
||||
|
||||
// Any failure below this point must abort the reader task, otherwise
|
||||
// the newly-opened upstream TCP connection would leak. Keep the
|
||||
// abort paths explicit rather than burying them in `.map_err`.
|
||||
if let Some(ref data_b64) = data {
|
||||
if !data_b64.is_empty() {
|
||||
let bytes = match B64.decode(data_b64) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
session.reader_handle.abort();
|
||||
return Err(TunnelResponse::error(format!("bad base64: {}", e)));
|
||||
}
|
||||
};
|
||||
if !bytes.is_empty() {
|
||||
let mut w = session.inner.writer.lock().await;
|
||||
if let Err(e) = w.write_all(&bytes).await {
|
||||
drop(w);
|
||||
session.reader_handle.abort();
|
||||
return Err(TunnelResponse::error(format!("write failed: {}", e)));
|
||||
}
|
||||
let _ = w.flush().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let inner = session.inner.clone();
|
||||
let sid = uuid::Uuid::new_v4().to_string();
|
||||
tracing::info!("session {} -> {}:{} (connect_data)", sid, host, port);
|
||||
state.sessions.lock().await.insert(sid.clone(), session);
|
||||
Ok((sid, inner))
|
||||
}
|
||||
|
||||
async fn handle_connect_data_single(
|
||||
state: &AppState,
|
||||
host: Option<String>,
|
||||
port: Option<u16>,
|
||||
data: Option<String>,
|
||||
) -> TunnelResponse {
|
||||
let (sid, inner) = match handle_connect_data_phase1(state, host, port, data).await {
|
||||
Ok(v) => v,
|
||||
Err(r) => return r,
|
||||
};
|
||||
let (data, eof) = wait_and_drain(&inner, Duration::from_secs(5)).await;
|
||||
if eof {
|
||||
if let Some(s) = state.sessions.lock().await.remove(&sid) {
|
||||
s.reader_handle.abort();
|
||||
tracing::info!("session {} closed by remote", sid);
|
||||
}
|
||||
}
|
||||
TunnelResponse {
|
||||
sid: Some(sid),
|
||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||
eof: Some(eof),
|
||||
e: None,
|
||||
code: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<String>) -> TunnelResponse {
|
||||
@@ -428,7 +583,7 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
|
||||
TunnelResponse {
|
||||
sid: Some(sid),
|
||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||
eof: Some(eof), e: None,
|
||||
eof: Some(eof), e: None, code: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -441,7 +596,7 @@ async fn handle_close(state: &AppState, sid: Option<String>) -> TunnelResponse {
|
||||
s.reader_handle.abort();
|
||||
tracing::info!("session {} closed by client", sid);
|
||||
}
|
||||
TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None }
|
||||
TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None, code: None }
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -520,3 +675,129 @@ async fn main() {
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
fn fresh_state() -> AppState {
|
||||
AppState {
|
||||
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||
auth_key: "test-key".into(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Spin up a one-shot TCP server that echoes everything it reads back
|
||||
/// with a `"ECHO: "` prefix, then returns the bound port.
|
||||
async fn start_echo_server() -> u16 {
|
||||
let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
|
||||
let port = listener.local_addr().unwrap().port();
|
||||
tokio::spawn(async move {
|
||||
if let Ok((mut sock, _)) = listener.accept().await {
|
||||
let mut buf = [0u8; 1024];
|
||||
if let Ok(n) = sock.read(&mut buf).await {
|
||||
let mut out = b"ECHO: ".to_vec();
|
||||
out.extend_from_slice(&buf[..n]);
|
||||
let _ = sock.write_all(&out).await;
|
||||
let _ = sock.flush().await;
|
||||
}
|
||||
}
|
||||
});
|
||||
port
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsupported_op_response_has_structured_code() {
|
||||
let resp = TunnelResponse::unsupported_op("connect_data");
|
||||
assert_eq!(resp.code.as_deref(), Some(CODE_UNSUPPORTED_OP));
|
||||
assert_eq!(resp.e.as_deref(), Some("unknown op: connect_data"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn validate_host_port_rejects_empty_and_zero() {
|
||||
assert!(validate_host_port(None, Some(443)).is_err());
|
||||
assert!(validate_host_port(Some("".into()), Some(443)).is_err());
|
||||
assert!(validate_host_port(Some("x".into()), None).is_err());
|
||||
assert!(validate_host_port(Some("x".into()), Some(0)).is_err());
|
||||
assert_eq!(
|
||||
validate_host_port(Some("host".into()), Some(443)).unwrap(),
|
||||
("host".to_string(), 443),
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_data_phase1_writes_initial_data_and_returns_inner() {
|
||||
let port = start_echo_server().await;
|
||||
let state = fresh_state();
|
||||
|
||||
let (sid, inner) = handle_connect_data_phase1(
|
||||
&state,
|
||||
Some("127.0.0.1".into()),
|
||||
Some(port),
|
||||
Some(B64.encode(b"hello")),
|
||||
)
|
||||
.await
|
||||
.expect("phase1 should succeed");
|
||||
|
||||
// Session was inserted.
|
||||
assert!(state.sessions.lock().await.contains_key(&sid));
|
||||
|
||||
// Echo server sent back "ECHO: hello". Use wait_and_drain on the
|
||||
// returned Arc — no map re-lookup needed (this is the fix).
|
||||
let (data, _eof) = wait_and_drain(&inner, Duration::from_secs(2)).await;
|
||||
assert_eq!(&data[..], b"ECHO: hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_data_single_bundles_connect_and_first_bytes() {
|
||||
let port = start_echo_server().await;
|
||||
let state = fresh_state();
|
||||
|
||||
let resp = handle_connect_data_single(
|
||||
&state,
|
||||
Some("127.0.0.1".into()),
|
||||
Some(port),
|
||||
Some(B64.encode(b"world")),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(resp.e.is_none(), "unexpected error: {:?}", resp.e);
|
||||
assert!(resp.sid.is_some());
|
||||
let decoded = B64.decode(resp.d.unwrap()).unwrap();
|
||||
assert_eq!(&decoded[..], b"ECHO: world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_data_rejects_missing_host() {
|
||||
let state = fresh_state();
|
||||
let resp = handle_connect_data_single(
|
||||
&state, None, Some(443), Some(B64.encode(b"x")),
|
||||
).await;
|
||||
assert!(resp.e.as_deref().unwrap_or("").contains("missing host"));
|
||||
assert!(state.sessions.lock().await.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_data_rejects_bad_base64_and_does_not_leak_session() {
|
||||
// Need a live target so we reach the base64-decode step after
|
||||
// create_session succeeds — otherwise we'd fail earlier.
|
||||
let port = start_echo_server().await;
|
||||
let state = fresh_state();
|
||||
let resp = handle_connect_data_single(
|
||||
&state,
|
||||
Some("127.0.0.1".into()),
|
||||
Some(port),
|
||||
Some("!!!not base64!!!".into()),
|
||||
)
|
||||
.await;
|
||||
assert!(resp.e.as_deref().unwrap_or("").contains("bad base64"));
|
||||
// Session should NOT be in the map since phase1 rejected it.
|
||||
assert!(state.sessions.lock().await.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user