mirror of
https://github.com/therealaleph/MasterHttpRelayVPN-RUST.git
synced 2026-05-18 07:44:47 +03:00
feat(tunnel): save one RTT per new HTTPS flow via connect_data op
This commit is contained in:
@@ -56,6 +56,12 @@ function _doTunnel(req) {
|
|||||||
payload.host = req.h;
|
payload.host = req.h;
|
||||||
payload.port = req.p;
|
payload.port = req.p;
|
||||||
break;
|
break;
|
||||||
|
case "connect_data":
|
||||||
|
payload.op = "connect_data";
|
||||||
|
payload.host = req.h;
|
||||||
|
payload.port = req.p;
|
||||||
|
if (req.d) payload.data = req.d;
|
||||||
|
break;
|
||||||
case "data":
|
case "data":
|
||||||
payload.op = "data";
|
payload.op = "data";
|
||||||
payload.sid = req.sid;
|
payload.sid = req.sid;
|
||||||
@@ -66,7 +72,10 @@ function _doTunnel(req) {
|
|||||||
payload.sid = req.sid;
|
payload.sid = req.sid;
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
return _json({ e: "unknown tunnel op: " + req.t });
|
// Structured `code` lets the Rust client detect version skew
|
||||||
|
// without substring-matching the error text. Must match
|
||||||
|
// CODE_UNSUPPORTED_OP in tunnel_client.rs and tunnel-node/src/main.rs.
|
||||||
|
return _json({ e: "unknown tunnel op: " + req.t, code: "UNSUPPORTED_OP" });
|
||||||
}
|
}
|
||||||
|
|
||||||
var resp = UrlFetchApp.fetch(TUNNEL_SERVER_URL + "/tunnel", {
|
var resp = UrlFetchApp.fetch(TUNNEL_SERVER_URL + "/tunnel", {
|
||||||
|
|||||||
@@ -180,6 +180,11 @@ pub struct TunnelResponse {
|
|||||||
pub eof: Option<bool>,
|
pub eof: Option<bool>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub e: Option<String>,
|
pub e: Option<String>,
|
||||||
|
/// Structured error code from the tunnel-node (e.g. `UNSUPPORTED_OP`).
|
||||||
|
/// `None` for legacy tunnel-nodes; clients should fall back to parsing
|
||||||
|
/// `e` only when this is `None` and compatibility is needed.
|
||||||
|
#[serde(default)]
|
||||||
|
pub code: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// A single op in a batch tunnel request.
|
/// A single op in a batch tunnel request.
|
||||||
|
|||||||
+594
-47
@@ -6,12 +6,13 @@
|
|||||||
//! 30 in-flight requests — matching the per-account Apps Script limit.
|
//! 30 in-flight requests — matching the per-account Apps Script limit.
|
||||||
|
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use std::time::{Duration, Instant};
|
||||||
|
|
||||||
use base64::engine::general_purpose::STANDARD as B64;
|
use base64::engine::general_purpose::STANDARD as B64;
|
||||||
use base64::Engine;
|
use base64::Engine;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::{mpsc, oneshot, Semaphore};
|
use tokio::sync::{mpsc, oneshot, Semaphore};
|
||||||
|
|
||||||
@@ -41,6 +42,25 @@ const BATCH_TIMEOUT: Duration = Duration::from_secs(30);
|
|||||||
/// blocking indefinitely.
|
/// blocking indefinitely.
|
||||||
const REPLY_TIMEOUT: Duration = Duration::from_secs(35);
|
const REPLY_TIMEOUT: Duration = Duration::from_secs(35);
|
||||||
|
|
||||||
|
/// How long we'll briefly hold the client socket after the local
|
||||||
|
/// CONNECT/SOCKS5 handshake, waiting for the client's first bytes (the
|
||||||
|
/// TLS ClientHello for HTTPS). Bundling those bytes with the tunnel-node
|
||||||
|
/// connect saves one Apps Script round-trip per new flow.
|
||||||
|
const CLIENT_FIRST_DATA_WAIT: Duration = Duration::from_millis(50);
|
||||||
|
|
||||||
|
/// Structured error code the tunnel-node returns when it doesn't know the
|
||||||
|
/// op (version mismatch). Must match `tunnel-node/src/main.rs`.
|
||||||
|
const CODE_UNSUPPORTED_OP: &str = "UNSUPPORTED_OP";
|
||||||
|
|
||||||
|
/// Ports where the *server* speaks first (SMTP banner, SSH identification,
|
||||||
|
/// POP3/IMAP greeting, FTP banner). On these, waiting for client bytes
|
||||||
|
/// gains nothing and just adds handshake latency — skip the pre-read.
|
||||||
|
/// HTTP on 80 also qualifies because a naive HTTP client may not flush
|
||||||
|
/// the request line immediately after the CONNECT reply.
|
||||||
|
fn is_server_speaks_first(port: u16) -> bool {
|
||||||
|
matches!(port, 21 | 22 | 25 | 80 | 110 | 143 | 587)
|
||||||
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Multiplexer
|
// Multiplexer
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -51,6 +71,14 @@ enum MuxMsg {
|
|||||||
port: u16,
|
port: u16,
|
||||||
reply: oneshot::Sender<Result<TunnelResponse, String>>,
|
reply: oneshot::Sender<Result<TunnelResponse, String>>,
|
||||||
},
|
},
|
||||||
|
ConnectData {
|
||||||
|
host: String,
|
||||||
|
port: u16,
|
||||||
|
// Arc so the caller can hand the buffer to the mux AND keep a ref
|
||||||
|
// for the fallback path without an extra 64 KB copy per session.
|
||||||
|
data: Arc<Vec<u8>>,
|
||||||
|
reply: oneshot::Sender<Result<TunnelResponse, String>>,
|
||||||
|
},
|
||||||
Data {
|
Data {
|
||||||
sid: String,
|
sid: String,
|
||||||
data: Vec<u8>,
|
data: Vec<u8>,
|
||||||
@@ -63,6 +91,27 @@ enum MuxMsg {
|
|||||||
|
|
||||||
pub struct TunnelMux {
|
pub struct TunnelMux {
|
||||||
tx: mpsc::Sender<MuxMsg>,
|
tx: mpsc::Sender<MuxMsg>,
|
||||||
|
/// Set to `true` after the first time the tunnel-node rejects
|
||||||
|
/// `connect_data` as unsupported. Subsequent sessions skip the
|
||||||
|
/// optimistic path entirely and go straight to plain connect + data.
|
||||||
|
connect_data_unsupported: Arc<AtomicBool>,
|
||||||
|
/// Pre-read observability. Lets an operator see whether the 50 ms
|
||||||
|
/// wait-for-first-bytes is pulling its weight:
|
||||||
|
/// * `preread_win` — client sent bytes in time, bundled with connect
|
||||||
|
/// * `preread_loss` — timed out empty; paid 50 ms for nothing
|
||||||
|
/// * `preread_skip_port` — port was server-speaks-first; skipped wait
|
||||||
|
/// * `preread_skip_unsupported` — tunnel-node said no; skipped wait
|
||||||
|
/// A rolling sum of win-time (µs) drives a `mean_win_time` readout so
|
||||||
|
/// you can tune `CLIENT_FIRST_DATA_WAIT` against real client flush
|
||||||
|
/// timing. A summary line is logged every 100 preread events.
|
||||||
|
preread_win: AtomicU64,
|
||||||
|
preread_loss: AtomicU64,
|
||||||
|
preread_skip_port: AtomicU64,
|
||||||
|
preread_skip_unsupported: AtomicU64,
|
||||||
|
preread_win_total_us: AtomicU64,
|
||||||
|
/// Separate monotonic counter used only to trigger the summary log
|
||||||
|
/// (avoids a race where two threads both see `total % 100 == 0`).
|
||||||
|
preread_total_events: AtomicU64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TunnelMux {
|
impl TunnelMux {
|
||||||
@@ -75,18 +124,88 @@ impl TunnelMux {
|
|||||||
);
|
);
|
||||||
let (tx, rx) = mpsc::channel(512);
|
let (tx, rx) = mpsc::channel(512);
|
||||||
tokio::spawn(mux_loop(rx, fronter));
|
tokio::spawn(mux_loop(rx, fronter));
|
||||||
Arc::new(Self { tx })
|
Arc::new(Self {
|
||||||
|
tx,
|
||||||
|
connect_data_unsupported: Arc::new(AtomicBool::new(false)),
|
||||||
|
preread_win: AtomicU64::new(0),
|
||||||
|
preread_loss: AtomicU64::new(0),
|
||||||
|
preread_skip_port: AtomicU64::new(0),
|
||||||
|
preread_skip_unsupported: AtomicU64::new(0),
|
||||||
|
preread_win_total_us: AtomicU64::new(0),
|
||||||
|
preread_total_events: AtomicU64::new(0),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, msg: MuxMsg) {
|
async fn send(&self, msg: MuxMsg) {
|
||||||
let _ = self.tx.send(msg).await;
|
let _ = self.tx.send(msg).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn connect_data_unsupported(&self) -> bool {
|
||||||
|
self.connect_data_unsupported.load(Ordering::Relaxed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn mark_connect_data_unsupported(&self) {
|
||||||
|
if !self.connect_data_unsupported.swap(true, Ordering::Relaxed) {
|
||||||
|
tracing::warn!(
|
||||||
|
"tunnel-node doesn't support connect_data (pre-v1.x); falling back to plain connect + data for all future sessions"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_preread_win(&self, port: u16, elapsed: Duration) {
|
||||||
|
self.preread_win.fetch_add(1, Ordering::Relaxed);
|
||||||
|
self.preread_win_total_us
|
||||||
|
.fetch_add(elapsed.as_micros() as u64, Ordering::Relaxed);
|
||||||
|
tracing::debug!("preread win: port={} took={:?}", port, elapsed);
|
||||||
|
self.maybe_log_preread_summary();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_preread_loss(&self, port: u16) {
|
||||||
|
self.preread_loss.fetch_add(1, Ordering::Relaxed);
|
||||||
|
tracing::debug!("preread loss: port={} (empty within {:?})", port, CLIENT_FIRST_DATA_WAIT);
|
||||||
|
self.maybe_log_preread_summary();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_preread_skip_port(&self, port: u16) {
|
||||||
|
self.preread_skip_port.fetch_add(1, Ordering::Relaxed);
|
||||||
|
tracing::debug!("preread skip: port={} (server-speaks-first)", port);
|
||||||
|
self.maybe_log_preread_summary();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn record_preread_skip_unsupported(&self, port: u16) {
|
||||||
|
self.preread_skip_unsupported.fetch_add(1, Ordering::Relaxed);
|
||||||
|
tracing::debug!("preread skip: port={} (connect_data unsupported)", port);
|
||||||
|
self.maybe_log_preread_summary();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Emit an aggregate summary exactly once per 100 preread events.
|
||||||
|
/// Using a dedicated counter for the trigger avoids a race where two
|
||||||
|
/// threads both observe the win/loss/skip totals summing to a
|
||||||
|
/// multiple of 100 — here, exactly one thread gets the boundary.
|
||||||
|
fn maybe_log_preread_summary(&self) {
|
||||||
|
let new_count = self.preread_total_events.fetch_add(1, Ordering::Relaxed) + 1;
|
||||||
|
if new_count % 100 != 0 {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
let win = self.preread_win.load(Ordering::Relaxed);
|
||||||
|
let loss = self.preread_loss.load(Ordering::Relaxed);
|
||||||
|
let skip_port = self.preread_skip_port.load(Ordering::Relaxed);
|
||||||
|
let skip_unsup = self.preread_skip_unsupported.load(Ordering::Relaxed);
|
||||||
|
let total_us = self.preread_win_total_us.load(Ordering::Relaxed);
|
||||||
|
let mean_us = if win > 0 { total_us / win } else { 0 };
|
||||||
|
tracing::info!(
|
||||||
|
"connect_data preread: {} win / {} loss / {} skip(port) / {} skip(unsup), mean win time {}µs (ceiling {}µs)",
|
||||||
|
win,
|
||||||
|
loss,
|
||||||
|
skip_port,
|
||||||
|
skip_unsup,
|
||||||
|
mean_us,
|
||||||
|
CLIENT_FIRST_DATA_WAIT.as_micros(),
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn mux_loop(
|
async fn mux_loop(mut rx: mpsc::Receiver<MuxMsg>, fronter: Arc<DomainFronter>) {
|
||||||
mut rx: mpsc::Receiver<MuxMsg>,
|
|
||||||
fronter: Arc<DomainFronter>,
|
|
||||||
) {
|
|
||||||
// One semaphore per deployment ID, each allowing 30 concurrent requests.
|
// One semaphore per deployment ID, each allowing 30 concurrent requests.
|
||||||
let sems: Arc<HashMap<String, Arc<Semaphore>>> = Arc::new(
|
let sems: Arc<HashMap<String, Arc<Semaphore>>> = Arc::new(
|
||||||
fronter
|
fronter
|
||||||
@@ -107,7 +226,7 @@ async fn mux_loop(
|
|||||||
msgs.push(msg);
|
msgs.push(msg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Split: connects go parallel, data/close get batched.
|
// Split: plain connects go parallel, data-bearing ops get batched.
|
||||||
let mut data_ops: Vec<BatchOp> = Vec::new();
|
let mut data_ops: Vec<BatchOp> = Vec::new();
|
||||||
let mut data_replies: Vec<(usize, oneshot::Sender<Result<TunnelResponse, String>>)> =
|
let mut data_replies: Vec<(usize, oneshot::Sender<Result<TunnelResponse, String>>)> =
|
||||||
Vec::new();
|
Vec::new();
|
||||||
@@ -128,6 +247,35 @@ async fn mux_loop(
|
|||||||
}
|
}
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
MuxMsg::ConnectData { host, port, data, reply } => {
|
||||||
|
let encoded = Some(B64.encode(data.as_slice()));
|
||||||
|
let op_bytes = encoded.as_ref().map(|s| s.len()).unwrap_or(0);
|
||||||
|
|
||||||
|
if !data_ops.is_empty()
|
||||||
|
&& (data_ops.len() >= MAX_BATCH_OPS
|
||||||
|
|| batch_payload_bytes + op_bytes > MAX_BATCH_PAYLOAD_BYTES)
|
||||||
|
{
|
||||||
|
fire_batch(
|
||||||
|
&sems,
|
||||||
|
&fronter,
|
||||||
|
std::mem::take(&mut data_ops),
|
||||||
|
std::mem::take(&mut data_replies),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
batch_payload_bytes = 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let idx = data_ops.len();
|
||||||
|
data_ops.push(BatchOp {
|
||||||
|
op: "connect_data".into(),
|
||||||
|
sid: None,
|
||||||
|
host: Some(host),
|
||||||
|
port: Some(port),
|
||||||
|
d: encoded,
|
||||||
|
});
|
||||||
|
data_replies.push((idx, reply));
|
||||||
|
batch_payload_bytes += op_bytes;
|
||||||
|
}
|
||||||
MuxMsg::Data { sid, data, reply } => {
|
MuxMsg::Data { sid, data, reply } => {
|
||||||
let encoded = if data.is_empty() {
|
let encoded = if data.is_empty() {
|
||||||
None
|
None
|
||||||
@@ -219,7 +367,12 @@ async fn fire_batch(
|
|||||||
f.tunnel_batch_request_to(&script_id, &data_ops),
|
f.tunnel_batch_request_to(&script_id, &data_ops),
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
tracing::info!("batch: {} ops → {}, rtt={:?}", n_ops, &script_id[..script_id.len().min(8)], t0.elapsed());
|
tracing::info!(
|
||||||
|
"batch: {} ops → {}, rtt={:?}",
|
||||||
|
n_ops,
|
||||||
|
&script_id[..script_id.len().min(8)],
|
||||||
|
t0.elapsed()
|
||||||
|
);
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(Ok(batch_resp)) => {
|
Ok(Ok(batch_resp)) => {
|
||||||
@@ -258,6 +411,95 @@ pub async fn tunnel_connection(
|
|||||||
port: u16,
|
port: u16,
|
||||||
mux: &Arc<TunnelMux>,
|
mux: &Arc<TunnelMux>,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
|
// Only try the bundled connect+data optimization when it's likely to
|
||||||
|
// pay off — client-speaks-first protocols (TLS on 443 et al.) — and
|
||||||
|
// only if the tunnel-node has already accepted `connect_data` at least
|
||||||
|
// once this process lifetime (or we haven't tried yet). Check the
|
||||||
|
// fallback cache first so `skip(unsup)` shadows `skip(port)` in the
|
||||||
|
// metrics once the feature is disabled process-wide.
|
||||||
|
let initial_data = if mux.connect_data_unsupported() {
|
||||||
|
mux.record_preread_skip_unsupported(port);
|
||||||
|
None
|
||||||
|
} else if is_server_speaks_first(port) {
|
||||||
|
mux.record_preread_skip_port(port);
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mut buf = vec![0u8; 65536];
|
||||||
|
let t0 = Instant::now();
|
||||||
|
match tokio::time::timeout(CLIENT_FIRST_DATA_WAIT, sock.read(&mut buf)).await {
|
||||||
|
Ok(Ok(0)) => return Ok(()),
|
||||||
|
Ok(Ok(n)) => {
|
||||||
|
mux.record_preread_win(port, t0.elapsed());
|
||||||
|
buf.truncate(n);
|
||||||
|
Some(Arc::new(buf))
|
||||||
|
}
|
||||||
|
Ok(Err(e)) => return Err(e),
|
||||||
|
Err(_) => {
|
||||||
|
mux.record_preread_loss(port);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (sid, first_resp, pending_client_data) = match initial_data {
|
||||||
|
Some(data) => match connect_with_initial_data(host, port, data.clone(), mux).await? {
|
||||||
|
ConnectDataOutcome::Opened { sid, response } => (sid, Some(response), None),
|
||||||
|
ConnectDataOutcome::Unsupported => {
|
||||||
|
mux.mark_connect_data_unsupported();
|
||||||
|
let sid = connect_plain(host, port, mux).await?;
|
||||||
|
// Recover the buffered ClientHello from the Arc so the
|
||||||
|
// first tunnel_loop iteration can replay it. The mux task
|
||||||
|
// may still hold the other ref during the unsupported
|
||||||
|
// reply's settle window — fall back to a clone in that
|
||||||
|
// race (rare; the reply path drops its ref before we
|
||||||
|
// reach here in practice).
|
||||||
|
let bytes = Arc::try_unwrap(data).unwrap_or_else(|a| (*a).clone());
|
||||||
|
(sid, None, Some(bytes))
|
||||||
|
}
|
||||||
|
},
|
||||||
|
None => (connect_plain(host, port, mux).await?, None, None),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::info!("tunnel session {} opened for {}:{}", sid, host, port);
|
||||||
|
|
||||||
|
// Run the first-response write + tunnel_loop inside an async block so
|
||||||
|
// any io-error propagates via `?` without bypassing the Close below.
|
||||||
|
// We deliberately don't use a Drop guard for Close: a Drop impl can't
|
||||||
|
// .await cleanly, and tokio::spawn from inside Drop is unreliable
|
||||||
|
// during runtime shutdown. The explicit send below covers every
|
||||||
|
// non-panic path; a panic during tunnel_loop would leak the session
|
||||||
|
// on the tunnel-node until its 5-minute idle reaper runs.
|
||||||
|
let result = async {
|
||||||
|
if let Some(resp) = first_resp {
|
||||||
|
match write_tunnel_response(&mut sock, &resp).await? {
|
||||||
|
WriteOutcome::Wrote | WriteOutcome::NoData => {}
|
||||||
|
WriteOutcome::BadBase64 => {
|
||||||
|
tracing::error!("tunnel session {}: bad base64 in connect_data response", sid);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if resp.eof.unwrap_or(false) {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tunnel_loop(&mut sock, &sid, mux, pending_client_data).await
|
||||||
|
}
|
||||||
|
.await;
|
||||||
|
|
||||||
|
mux.send(MuxMsg::Close { sid: sid.clone() }).await;
|
||||||
|
tracing::info!("tunnel session {} closed for {}:{}", sid, host, port);
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
enum ConnectDataOutcome {
|
||||||
|
Opened {
|
||||||
|
sid: String,
|
||||||
|
response: TunnelResponse,
|
||||||
|
},
|
||||||
|
Unsupported,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_plain(host: &str, port: u16, mux: &Arc<TunnelMux>) -> std::io::Result<String> {
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
mux.send(MuxMsg::Connect {
|
mux.send(MuxMsg::Connect {
|
||||||
host: host.to_string(),
|
host: host.to_string(),
|
||||||
@@ -266,7 +508,7 @@ pub async fn tunnel_connection(
|
|||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let sid = match reply_rx.await {
|
match reply_rx.await {
|
||||||
Ok(Ok(resp)) => {
|
Ok(Ok(resp)) => {
|
||||||
if let Some(ref e) = resp.e {
|
if let Some(ref e) = resp.e {
|
||||||
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
|
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
|
||||||
@@ -277,10 +519,45 @@ pub async fn tunnel_connection(
|
|||||||
}
|
}
|
||||||
resp.sid.ok_or_else(|| {
|
resp.sid.ok_or_else(|| {
|
||||||
std::io::Error::new(std::io::ErrorKind::Other, "tunnel connect: no session id")
|
std::io::Error::new(std::io::ErrorKind::Other, "tunnel connect: no session id")
|
||||||
})?
|
})
|
||||||
}
|
}
|
||||||
Ok(Err(e)) => {
|
Ok(Err(e)) => {
|
||||||
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
|
tracing::error!("tunnel connect error for {}:{}: {}", host, port, e);
|
||||||
|
Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::ConnectionRefused,
|
||||||
|
e,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
Err(_) => Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
"mux channel closed",
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn connect_with_initial_data(
|
||||||
|
host: &str,
|
||||||
|
port: u16,
|
||||||
|
data: Arc<Vec<u8>>,
|
||||||
|
mux: &Arc<TunnelMux>,
|
||||||
|
) -> std::io::Result<ConnectDataOutcome> {
|
||||||
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
|
mux.send(MuxMsg::ConnectData {
|
||||||
|
host: host.to_string(),
|
||||||
|
port,
|
||||||
|
data,
|
||||||
|
reply: reply_tx,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let resp = match reply_rx.await {
|
||||||
|
Ok(Ok(resp)) => resp,
|
||||||
|
Ok(Err(e)) => {
|
||||||
|
if is_connect_data_unsupported_error_str(&e) {
|
||||||
|
tracing::debug!("connect_data unsupported for {}:{}: {}", host, port, e);
|
||||||
|
return Ok(ConnectDataOutcome::Unsupported);
|
||||||
|
}
|
||||||
|
tracing::error!("tunnel connect_data error for {}:{}: {}", host, port, e);
|
||||||
return Err(std::io::Error::new(
|
return Err(std::io::Error::new(
|
||||||
std::io::ErrorKind::ConnectionRefused,
|
std::io::ErrorKind::ConnectionRefused,
|
||||||
e,
|
e,
|
||||||
@@ -294,38 +571,96 @@ pub async fn tunnel_connection(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
tracing::info!("tunnel session {} opened for {}:{}", sid, host, port);
|
if is_connect_data_unsupported_response(&resp) {
|
||||||
let result = tunnel_loop(&mut sock, &sid, mux).await;
|
tracing::debug!(
|
||||||
mux.send(MuxMsg::Close { sid: sid.clone() }).await;
|
"connect_data unsupported for {}:{}: {:?}",
|
||||||
tracing::info!("tunnel session {} closed for {}:{}", sid, host, port);
|
host,
|
||||||
result
|
port,
|
||||||
|
resp.e
|
||||||
|
);
|
||||||
|
return Ok(ConnectDataOutcome::Unsupported);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(ref e) = resp.e {
|
||||||
|
tracing::error!("tunnel connect_data error for {}:{}: {}", host, port, e);
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::ConnectionRefused,
|
||||||
|
e.clone(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(sid) = resp.sid.clone() else {
|
||||||
|
return Err(std::io::Error::new(
|
||||||
|
std::io::ErrorKind::Other,
|
||||||
|
"tunnel connect_data: no session id",
|
||||||
|
));
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(ConnectDataOutcome::Opened { sid, response: resp })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decide whether a response indicates the tunnel-node (or apps_script
|
||||||
|
/// layer in front of it) didn't recognize `connect_data`.
|
||||||
|
///
|
||||||
|
/// Primary signal: the structured `code` field (`UNSUPPORTED_OP`), emitted
|
||||||
|
/// by any tunnel-node or apps_script deployment that has this change.
|
||||||
|
/// Fallback signal (for legacy deployments, pre-connect_data): substring
|
||||||
|
/// match on the stable error string. The string-match is a one-way
|
||||||
|
/// compatibility hatch — newer deployments set `code` so future refactors
|
||||||
|
/// of the error text won't silently break detection.
|
||||||
|
///
|
||||||
|
/// Two error shapes are possible on the legacy path:
|
||||||
|
/// * tunnel-node's single-op/batch handler: `"unknown op: connect_data"`
|
||||||
|
/// * apps_script's `_doTunnel` default branch: `"unknown tunnel op: connect_data"`
|
||||||
|
///
|
||||||
|
/// Apps_script and tunnel-node ship on independent cadences, so it is
|
||||||
|
/// realistic for a user to upgrade one but not the other — detection has
|
||||||
|
/// to cover both shapes or the feature hard-fails on version skew.
|
||||||
|
fn is_connect_data_unsupported_response(resp: &TunnelResponse) -> bool {
|
||||||
|
if resp.code.as_deref() == Some(CODE_UNSUPPORTED_OP) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
resp.e
|
||||||
|
.as_deref()
|
||||||
|
.map(is_connect_data_unsupported_error_str)
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_connect_data_unsupported_error_str(e: &str) -> bool {
|
||||||
|
let e = e.to_ascii_lowercase();
|
||||||
|
(e.contains("unknown op") || e.contains("unknown tunnel op")) && e.contains("connect_data")
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn tunnel_loop(
|
async fn tunnel_loop(
|
||||||
sock: &mut TcpStream,
|
sock: &mut TcpStream,
|
||||||
sid: &str,
|
sid: &str,
|
||||||
mux: &Arc<TunnelMux>,
|
mux: &Arc<TunnelMux>,
|
||||||
|
mut pending_client_data: Option<Vec<u8>>,
|
||||||
) -> std::io::Result<()> {
|
) -> std::io::Result<()> {
|
||||||
let (mut reader, mut writer) = sock.split();
|
let (mut reader, mut writer) = sock.split();
|
||||||
let mut buf = vec![0u8; 65536];
|
let mut buf = vec![0u8; 65536];
|
||||||
let mut consecutive_empty = 0u32;
|
let mut consecutive_empty = 0u32;
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let read_timeout = match consecutive_empty {
|
let client_data = if let Some(data) = pending_client_data.take() {
|
||||||
0 => Duration::from_millis(20),
|
Some(data)
|
||||||
1 => Duration::from_millis(80),
|
} else {
|
||||||
2 => Duration::from_millis(200),
|
let read_timeout = match consecutive_empty {
|
||||||
_ => Duration::from_secs(30),
|
0 => Duration::from_millis(20),
|
||||||
};
|
1 => Duration::from_millis(80),
|
||||||
|
2 => Duration::from_millis(200),
|
||||||
|
_ => Duration::from_secs(30),
|
||||||
|
};
|
||||||
|
|
||||||
let client_data = match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
|
match tokio::time::timeout(read_timeout, reader.read(&mut buf)).await {
|
||||||
Ok(Ok(0)) => break,
|
Ok(Ok(0)) => break,
|
||||||
Ok(Ok(n)) => {
|
Ok(Ok(n)) => {
|
||||||
consecutive_empty = 0;
|
consecutive_empty = 0;
|
||||||
Some(buf[..n].to_vec())
|
Some(buf[..n].to_vec())
|
||||||
|
}
|
||||||
|
Ok(Err(_)) => break,
|
||||||
|
Err(_) => None,
|
||||||
}
|
}
|
||||||
Ok(Err(_)) => break,
|
|
||||||
Err(_) => None,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if client_data.is_none() && consecutive_empty > 3 {
|
if client_data.is_none() && consecutive_empty > 3 {
|
||||||
@@ -364,25 +699,15 @@ async fn tunnel_loop(
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
let got_data = if let Some(ref d) = resp.d {
|
let got_data = match write_tunnel_response(&mut writer, &resp).await? {
|
||||||
if !d.is_empty() {
|
WriteOutcome::Wrote => true,
|
||||||
match B64.decode(d) {
|
WriteOutcome::NoData => false,
|
||||||
Ok(bytes) if !bytes.is_empty() => {
|
WriteOutcome::BadBase64 => {
|
||||||
writer.write_all(&bytes).await?;
|
// Tunnel-node gave us garbage; tear the session down but
|
||||||
writer.flush().await?;
|
// do NOT propagate as an io error — the caller's Close
|
||||||
true
|
// guard will clean up on the tunnel-node side.
|
||||||
}
|
break;
|
||||||
Err(e) => {
|
|
||||||
tracing::error!("tunnel bad base64: {}", e);
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
_ => false,
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
}
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
};
|
||||||
|
|
||||||
if resp.eof.unwrap_or(false) {
|
if resp.eof.unwrap_or(false) {
|
||||||
@@ -398,3 +723,225 @@ async fn tunnel_loop(
|
|||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum WriteOutcome {
|
||||||
|
Wrote,
|
||||||
|
NoData,
|
||||||
|
BadBase64,
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn write_tunnel_response<W>(
|
||||||
|
writer: &mut W,
|
||||||
|
resp: &TunnelResponse,
|
||||||
|
) -> std::io::Result<WriteOutcome>
|
||||||
|
where
|
||||||
|
W: AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let Some(ref d) = resp.d else {
|
||||||
|
return Ok(WriteOutcome::NoData);
|
||||||
|
};
|
||||||
|
if d.is_empty() {
|
||||||
|
return Ok(WriteOutcome::NoData);
|
||||||
|
}
|
||||||
|
|
||||||
|
match B64.decode(d) {
|
||||||
|
Ok(bytes) if !bytes.is_empty() => {
|
||||||
|
writer.write_all(&bytes).await?;
|
||||||
|
writer.flush().await?;
|
||||||
|
Ok(WriteOutcome::Wrote)
|
||||||
|
}
|
||||||
|
Ok(_) => Ok(WriteOutcome::NoData),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("tunnel bad base64: {}", e);
|
||||||
|
Ok(WriteOutcome::BadBase64)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
fn resp_with(code: Option<&str>, e: Option<&str>) -> TunnelResponse {
|
||||||
|
TunnelResponse {
|
||||||
|
sid: None,
|
||||||
|
d: None,
|
||||||
|
eof: None,
|
||||||
|
e: e.map(str::to_string),
|
||||||
|
code: code.map(str::to_string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unsupported_detection_via_structured_code() {
|
||||||
|
assert!(is_connect_data_unsupported_response(&resp_with(Some("UNSUPPORTED_OP"), None)));
|
||||||
|
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||||
|
Some("UNSUPPORTED_OP"),
|
||||||
|
Some("unknown op: connect_data"),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unsupported_detection_via_legacy_tunnel_node_string() {
|
||||||
|
// Pre-change tunnel-node: no code field, bare "unknown op: ...".
|
||||||
|
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||||
|
None, Some("unknown op: connect_data"),
|
||||||
|
)));
|
||||||
|
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||||
|
None, Some("Unknown Op: CONNECT_DATA"),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unsupported_detection_via_legacy_apps_script_string() {
|
||||||
|
// Pre-change apps_script: default branch emits "unknown tunnel op: ...".
|
||||||
|
// This is the realistic skew case — user upgrades tunnel-node + client
|
||||||
|
// binary but hasn't redeployed the Apps Script yet.
|
||||||
|
assert!(is_connect_data_unsupported_response(&resp_with(
|
||||||
|
None, Some("unknown tunnel op: connect_data"),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unsupported_detection_rejects_unrelated_errors() {
|
||||||
|
assert!(!is_connect_data_unsupported_response(&resp_with(
|
||||||
|
None, Some("connect failed: refused"),
|
||||||
|
)));
|
||||||
|
assert!(!is_connect_data_unsupported_response(&resp_with(None, Some("bad base64"))));
|
||||||
|
assert!(!is_connect_data_unsupported_response(&resp_with(None, None)));
|
||||||
|
// "connect_data" alone (without "unknown op") shouldn't trigger.
|
||||||
|
assert!(!is_connect_data_unsupported_response(&resp_with(
|
||||||
|
None, Some("connect_data: bad port"),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn server_speaks_first_covers_common_protocols() {
|
||||||
|
for p in [21u16, 22, 25, 80, 110, 143, 587] {
|
||||||
|
assert!(is_server_speaks_first(p), "port {} should be server-first", p);
|
||||||
|
}
|
||||||
|
for p in [443u16, 8443, 853, 993, 1234] {
|
||||||
|
assert!(!is_server_speaks_first(p), "port {} should NOT be server-first", p);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a TunnelMux whose send channel is exposed to the test rather
|
||||||
|
/// than wired to a real DomainFronter. Lets tests assert what messages
|
||||||
|
/// the client would emit without needing network or apps_script.
|
||||||
|
fn mux_for_test() -> (Arc<TunnelMux>, mpsc::Receiver<MuxMsg>) {
|
||||||
|
let (tx, rx) = mpsc::channel(16);
|
||||||
|
let mux = Arc::new(TunnelMux {
|
||||||
|
tx,
|
||||||
|
connect_data_unsupported: Arc::new(AtomicBool::new(false)),
|
||||||
|
preread_win: AtomicU64::new(0),
|
||||||
|
preread_loss: AtomicU64::new(0),
|
||||||
|
preread_skip_port: AtomicU64::new(0),
|
||||||
|
preread_skip_unsupported: AtomicU64::new(0),
|
||||||
|
preread_win_total_us: AtomicU64::new(0),
|
||||||
|
preread_total_events: AtomicU64::new(0),
|
||||||
|
});
|
||||||
|
(mux, rx)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The buffered ClientHello from the pre-read window must reach the
|
||||||
|
/// tunnel-node as the first `Data` op on the fallback path. If this
|
||||||
|
/// regresses, every TLS handshake stalls until the 30 s read-timeout
|
||||||
|
/// fires — catastrophic and silent without a test.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn tunnel_loop_replays_pending_client_data_before_reading_socket() {
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
// Set up a loopback pair so tunnel_loop has a real TcpStream to
|
||||||
|
// work with. We never write to its peer, so tunnel_loop's "read
|
||||||
|
// from client" branch would block indefinitely — meaning any
|
||||||
|
// `Data` msg it emits must have come from pending_client_data.
|
||||||
|
let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
|
||||||
|
let addr = listener.local_addr().unwrap();
|
||||||
|
let accept = tokio::spawn(async move { listener.accept().await.unwrap().0 });
|
||||||
|
let _client = TcpStream::connect(addr).await.unwrap();
|
||||||
|
let mut server_side = accept.await.unwrap();
|
||||||
|
|
||||||
|
let (mux, mut rx) = mux_for_test();
|
||||||
|
let pending = Some(b"CLIENTHELLO".to_vec());
|
||||||
|
|
||||||
|
let loop_handle = tokio::spawn({
|
||||||
|
let mux = mux.clone();
|
||||||
|
async move {
|
||||||
|
tunnel_loop(&mut server_side, "sid-under-test", &mux, pending).await
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// The first message tunnel_loop emits must be Data carrying the
|
||||||
|
// replayed bytes — NOT whatever it would have read from the socket.
|
||||||
|
let msg = tokio::time::timeout(Duration::from_secs(2), rx.recv())
|
||||||
|
.await
|
||||||
|
.expect("tunnel_loop did not send a message within 2s")
|
||||||
|
.expect("mux channel closed unexpectedly");
|
||||||
|
|
||||||
|
match msg {
|
||||||
|
MuxMsg::Data { sid, data, reply } => {
|
||||||
|
assert_eq!(sid, "sid-under-test");
|
||||||
|
assert_eq!(&data[..], b"CLIENTHELLO");
|
||||||
|
// Reply with eof so tunnel_loop unwinds cleanly.
|
||||||
|
let _ = reply.send(Ok(TunnelResponse {
|
||||||
|
sid: Some("sid-under-test".into()),
|
||||||
|
d: None,
|
||||||
|
eof: Some(true),
|
||||||
|
e: None,
|
||||||
|
code: None,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
other => panic!(
|
||||||
|
"first mux message was not Data (expected replay); got {:?}",
|
||||||
|
match other {
|
||||||
|
MuxMsg::Connect { .. } => "Connect",
|
||||||
|
MuxMsg::ConnectData { .. } => "ConnectData",
|
||||||
|
MuxMsg::Data { .. } => unreachable!(),
|
||||||
|
MuxMsg::Close { .. } => "Close",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = tokio::time::timeout(Duration::from_secs(2), loop_handle)
|
||||||
|
.await
|
||||||
|
.expect("tunnel_loop did not exit after eof");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Once `mark_connect_data_unsupported` is called, future sessions
|
||||||
|
/// must see the flag — no per-session repeat of the detect-and-fallback
|
||||||
|
/// cost. If this regresses, every new flow pays an extra round trip
|
||||||
|
/// against a tunnel-node that will never learn the new op.
|
||||||
|
#[test]
|
||||||
|
fn unsupported_cache_is_sticky() {
|
||||||
|
let (mux, _rx) = mux_for_test();
|
||||||
|
assert!(!mux.connect_data_unsupported());
|
||||||
|
mux.mark_connect_data_unsupported();
|
||||||
|
assert!(mux.connect_data_unsupported());
|
||||||
|
mux.mark_connect_data_unsupported(); // idempotent
|
||||||
|
assert!(mux.connect_data_unsupported());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn preread_counters_track_each_outcome() {
|
||||||
|
let (mux, _rx) = mux_for_test();
|
||||||
|
|
||||||
|
mux.record_preread_win(443, Duration::from_micros(3_500));
|
||||||
|
mux.record_preread_win(443, Duration::from_micros(1_500));
|
||||||
|
mux.record_preread_loss(443);
|
||||||
|
mux.record_preread_skip_port(80);
|
||||||
|
mux.record_preread_skip_unsupported(443);
|
||||||
|
|
||||||
|
assert_eq!(mux.preread_win.load(Ordering::Relaxed), 2);
|
||||||
|
assert_eq!(mux.preread_loss.load(Ordering::Relaxed), 1);
|
||||||
|
assert_eq!(mux.preread_skip_port.load(Ordering::Relaxed), 1);
|
||||||
|
assert_eq!(mux.preread_skip_unsupported.load(Ordering::Relaxed), 1);
|
||||||
|
// Two wins summing to 5000 µs.
|
||||||
|
assert_eq!(mux.preread_win_total_us.load(Ordering::Relaxed), 5_000);
|
||||||
|
// Five record_* calls, so trigger counter is at 5.
|
||||||
|
assert_eq!(mux.preread_total_events.load(Ordering::Relaxed), 5);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
+299
-18
@@ -26,6 +26,12 @@ use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|||||||
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||||
use tokio::net::TcpStream;
|
use tokio::net::TcpStream;
|
||||||
use tokio::sync::Mutex;
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::task::JoinSet;
|
||||||
|
|
||||||
|
/// Structured error code returned when the tunnel-node receives an op it
|
||||||
|
/// doesn't recognize. Clients use this (rather than string-matching `e`) to
|
||||||
|
/// detect a version mismatch and gracefully fall back.
|
||||||
|
const CODE_UNSUPPORTED_OP: &str = "UNSUPPORTED_OP";
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
// Session
|
// Session
|
||||||
@@ -137,17 +143,25 @@ struct TunnelRequest {
|
|||||||
#[serde(default)] data: Option<String>,
|
#[serde(default)] data: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Serialize, Clone)]
|
#[derive(Serialize, Clone, Debug)]
|
||||||
struct TunnelResponse {
|
struct TunnelResponse {
|
||||||
#[serde(skip_serializing_if = "Option::is_none")] sid: Option<String>,
|
#[serde(skip_serializing_if = "Option::is_none")] sid: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")] d: Option<String>,
|
#[serde(skip_serializing_if = "Option::is_none")] d: Option<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")] eof: Option<bool>,
|
#[serde(skip_serializing_if = "Option::is_none")] eof: Option<bool>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")] e: Option<String>,
|
#[serde(skip_serializing_if = "Option::is_none")] e: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")] code: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TunnelResponse {
|
impl TunnelResponse {
|
||||||
fn error(msg: impl Into<String>) -> Self {
|
fn error(msg: impl Into<String>) -> Self {
|
||||||
Self { sid: None, d: None, eof: None, e: Some(msg.into()) }
|
Self { sid: None, d: None, eof: None, e: Some(msg.into()), code: None }
|
||||||
|
}
|
||||||
|
fn unsupported_op(op: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
sid: None, d: None, eof: None,
|
||||||
|
e: Some(format!("unknown op: {}", op)),
|
||||||
|
code: Some(CODE_UNSUPPORTED_OP.into()),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -188,9 +202,12 @@ async fn handle_tunnel(
|
|||||||
}
|
}
|
||||||
match req.op.as_str() {
|
match req.op.as_str() {
|
||||||
"connect" => Json(handle_connect(&state, req.host, req.port).await),
|
"connect" => Json(handle_connect(&state, req.host, req.port).await),
|
||||||
|
"connect_data" => {
|
||||||
|
Json(handle_connect_data_single(&state, req.host, req.port, req.data).await)
|
||||||
|
}
|
||||||
"data" => Json(handle_data_single(&state, req.sid, req.data).await),
|
"data" => Json(handle_data_single(&state, req.sid, req.data).await),
|
||||||
"close" => Json(handle_close(&state, req.sid).await),
|
"close" => Json(handle_close(&state, req.sid).await),
|
||||||
other => Json(TunnelResponse::error(format!("unknown op: {}", other))),
|
other => Json(TunnelResponse::unsupported_op(other)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,15 +255,51 @@ async fn handle_batch(
|
|||||||
// then do a short sleep to let servers respond, then drain all.
|
// then do a short sleep to let servers respond, then drain all.
|
||||||
// This batches the network round trips on the server side too.
|
// This batches the network round trips on the server side too.
|
||||||
|
|
||||||
// Phase 1: process connects and writes
|
// Phase 1: process connects and writes.
|
||||||
|
//
|
||||||
|
// `connect` and `connect_data` ops each establish a brand-new upstream
|
||||||
|
// TCP connection which can take up to 10 s (create_session timeout).
|
||||||
|
// Running them inline head-of-line-blocks every other op in the batch,
|
||||||
|
// so we dispatch both into a JoinSet and await them concurrently below.
|
||||||
|
//
|
||||||
|
// `connect_data` is expected to dominate in practice (new client) but
|
||||||
|
// we still hit `connect` from older clients or from server-speaks-first
|
||||||
|
// ports that skip the pre-read — if a slow `connect` landed in the same
|
||||||
|
// batch as data-bearing ops it could stall everyone.
|
||||||
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
|
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
|
||||||
let mut data_ops: Vec<(usize, String)> = Vec::new(); // (index, sid) for data ops needing drain
|
let mut data_ops: Vec<(usize, String)> = Vec::new(); // (index, sid) for data ops needing drain
|
||||||
|
|
||||||
|
enum NewConn {
|
||||||
|
Connect(TunnelResponse),
|
||||||
|
ConnectData(Result<String, TunnelResponse>),
|
||||||
|
}
|
||||||
|
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
|
||||||
|
|
||||||
for (i, op) in req.ops.iter().enumerate() {
|
for (i, op) in req.ops.iter().enumerate() {
|
||||||
match op.op.as_str() {
|
match op.op.as_str() {
|
||||||
"connect" => {
|
"connect" => {
|
||||||
let r = handle_connect(&state, op.host.clone(), op.port).await;
|
let state = state.clone();
|
||||||
results.push((i, r));
|
let host = op.host.clone();
|
||||||
|
let port = op.port;
|
||||||
|
new_conn_jobs.spawn(async move {
|
||||||
|
(i, NewConn::Connect(handle_connect(&state, host, port).await))
|
||||||
|
});
|
||||||
|
}
|
||||||
|
"connect_data" => {
|
||||||
|
let state = state.clone();
|
||||||
|
let host = op.host.clone();
|
||||||
|
let port = op.port;
|
||||||
|
let d = op.d.clone();
|
||||||
|
new_conn_jobs.spawn(async move {
|
||||||
|
// Drop the returned Arc<SessionInner>: phase 2 below
|
||||||
|
// holds the sessions-map lock once for the whole batch
|
||||||
|
// and re-looks up each sid, which is cheap. The Arc
|
||||||
|
// return is a convenience for the single-op path only.
|
||||||
|
let r = handle_connect_data_phase1(&state, host, port, d)
|
||||||
|
.await
|
||||||
|
.map(|(sid, _inner)| sid);
|
||||||
|
(i, NewConn::ConnectData(r))
|
||||||
|
});
|
||||||
}
|
}
|
||||||
"data" => {
|
"data" => {
|
||||||
let sid = match &op.sid {
|
let sid = match &op.sid {
|
||||||
@@ -273,7 +326,9 @@ async fn handle_batch(
|
|||||||
data_ops.push((i, sid));
|
data_ops.push((i, sid));
|
||||||
} else {
|
} else {
|
||||||
drop(sessions);
|
drop(sessions);
|
||||||
results.push((i, TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None }));
|
results.push((i, TunnelResponse {
|
||||||
|
sid: Some(sid), d: None, eof: Some(true), e: None, code: None,
|
||||||
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"close" => {
|
"close" => {
|
||||||
@@ -281,7 +336,21 @@ async fn handle_batch(
|
|||||||
results.push((i, r));
|
results.push((i, r));
|
||||||
}
|
}
|
||||||
other => {
|
other => {
|
||||||
results.push((i, TunnelResponse::error(format!("unknown op: {}", other))));
|
results.push((i, TunnelResponse::unsupported_op(other)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Await all concurrent connect / connect_data jobs. For connect_data,
|
||||||
|
// successful ones join the data-drain set in phase 2; plain connects
|
||||||
|
// go straight to results because they have no initial data to drain.
|
||||||
|
while let Some(join) = new_conn_jobs.join_next().await {
|
||||||
|
match join {
|
||||||
|
Ok((i, NewConn::Connect(r))) => results.push((i, r)),
|
||||||
|
Ok((i, NewConn::ConnectData(Ok(sid)))) => data_ops.push((i, sid)),
|
||||||
|
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("new-connection task panicked: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -304,12 +373,12 @@ async fn handle_batch(
|
|||||||
results.push((*i, TunnelResponse {
|
results.push((*i, TunnelResponse {
|
||||||
sid: Some(sid.clone()),
|
sid: Some(sid.clone()),
|
||||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||||
eof: Some(eof), e: None,
|
eof: Some(eof), e: None, code: None,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
results.push((*i, TunnelResponse {
|
results.push((*i, TunnelResponse {
|
||||||
sid: Some(sid.clone()), d: None, eof: Some(true), e: None,
|
sid: Some(sid.clone()), d: None, eof: Some(true), e: None, code: None,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -325,11 +394,11 @@ async fn handle_batch(
|
|||||||
results.push((*i, TunnelResponse {
|
results.push((*i, TunnelResponse {
|
||||||
sid: Some(sid.clone()),
|
sid: Some(sid.clone()),
|
||||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||||
eof: Some(eof), e: None,
|
eof: Some(eof), e: None, code: None,
|
||||||
}));
|
}));
|
||||||
} else {
|
} else {
|
||||||
results.push((*i, TunnelResponse {
|
results.push((*i, TunnelResponse {
|
||||||
sid: Some(sid.clone()), d: None, eof: Some(true), e: None,
|
sid: Some(sid.clone()), d: None, eof: Some(true), e: None, code: None,
|
||||||
}));
|
}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -372,14 +441,25 @@ fn decompress_gzip(data: &[u8]) -> Result<Vec<u8>, String> {
|
|||||||
// Shared op handlers
|
// Shared op handlers
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16>) -> TunnelResponse {
|
fn validate_host_port(
|
||||||
|
host: Option<String>,
|
||||||
|
port: Option<u16>,
|
||||||
|
) -> Result<(String, u16), TunnelResponse> {
|
||||||
let host = match host {
|
let host = match host {
|
||||||
Some(h) if !h.is_empty() => h,
|
Some(h) if !h.is_empty() => h,
|
||||||
_ => return TunnelResponse::error("missing host"),
|
_ => return Err(TunnelResponse::error("missing host")),
|
||||||
};
|
};
|
||||||
let port = match port {
|
let port = match port {
|
||||||
Some(p) if p > 0 => p,
|
Some(p) if p > 0 => p,
|
||||||
_ => return TunnelResponse::error("missing or invalid port"),
|
_ => return Err(TunnelResponse::error("missing or invalid port")),
|
||||||
|
};
|
||||||
|
Ok((host, port))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16>) -> TunnelResponse {
|
||||||
|
let (host, port) = match validate_host_port(host, port) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(r) => return r,
|
||||||
};
|
};
|
||||||
let session = match create_session(&host, port).await {
|
let session = match create_session(&host, port).await {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
@@ -388,7 +468,82 @@ async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16
|
|||||||
let sid = uuid::Uuid::new_v4().to_string();
|
let sid = uuid::Uuid::new_v4().to_string();
|
||||||
tracing::info!("session {} -> {}:{}", sid, host, port);
|
tracing::info!("session {} -> {}:{}", sid, host, port);
|
||||||
state.sessions.lock().await.insert(sid.clone(), session);
|
state.sessions.lock().await.insert(sid.clone(), session);
|
||||||
TunnelResponse { sid: Some(sid), d: None, eof: Some(false), e: None }
|
TunnelResponse { sid: Some(sid), d: None, eof: Some(false), e: None, code: None }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Open a session and write the client's first bytes in one round trip.
|
||||||
|
/// Returns the new sid plus an `Arc<SessionInner>` so unary callers
|
||||||
|
/// (`handle_connect_data_single`) can drain the first response without a
|
||||||
|
/// second sessions-map lookup. The batch caller drops the Arc — it takes
|
||||||
|
/// a single lock across all drain-bound sessions in phase 2, which is
|
||||||
|
/// cheaper than the Arc plumbing would be.
|
||||||
|
async fn handle_connect_data_phase1(
|
||||||
|
state: &AppState,
|
||||||
|
host: Option<String>,
|
||||||
|
port: Option<u16>,
|
||||||
|
data: Option<String>,
|
||||||
|
) -> Result<(String, Arc<SessionInner>), TunnelResponse> {
|
||||||
|
let (host, port) = validate_host_port(host, port)?;
|
||||||
|
|
||||||
|
let session = create_session(&host, port)
|
||||||
|
.await
|
||||||
|
.map_err(|e| TunnelResponse::error(format!("connect failed: {}", e)))?;
|
||||||
|
|
||||||
|
// Any failure below this point must abort the reader task, otherwise
|
||||||
|
// the newly-opened upstream TCP connection would leak. Keep the
|
||||||
|
// abort paths explicit rather than burying them in `.map_err`.
|
||||||
|
if let Some(ref data_b64) = data {
|
||||||
|
if !data_b64.is_empty() {
|
||||||
|
let bytes = match B64.decode(data_b64) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
session.reader_handle.abort();
|
||||||
|
return Err(TunnelResponse::error(format!("bad base64: {}", e)));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
if !bytes.is_empty() {
|
||||||
|
let mut w = session.inner.writer.lock().await;
|
||||||
|
if let Err(e) = w.write_all(&bytes).await {
|
||||||
|
drop(w);
|
||||||
|
session.reader_handle.abort();
|
||||||
|
return Err(TunnelResponse::error(format!("write failed: {}", e)));
|
||||||
|
}
|
||||||
|
let _ = w.flush().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let inner = session.inner.clone();
|
||||||
|
let sid = uuid::Uuid::new_v4().to_string();
|
||||||
|
tracing::info!("session {} -> {}:{} (connect_data)", sid, host, port);
|
||||||
|
state.sessions.lock().await.insert(sid.clone(), session);
|
||||||
|
Ok((sid, inner))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_connect_data_single(
|
||||||
|
state: &AppState,
|
||||||
|
host: Option<String>,
|
||||||
|
port: Option<u16>,
|
||||||
|
data: Option<String>,
|
||||||
|
) -> TunnelResponse {
|
||||||
|
let (sid, inner) = match handle_connect_data_phase1(state, host, port, data).await {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(r) => return r,
|
||||||
|
};
|
||||||
|
let (data, eof) = wait_and_drain(&inner, Duration::from_secs(5)).await;
|
||||||
|
if eof {
|
||||||
|
if let Some(s) = state.sessions.lock().await.remove(&sid) {
|
||||||
|
s.reader_handle.abort();
|
||||||
|
tracing::info!("session {} closed by remote", sid);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
TunnelResponse {
|
||||||
|
sid: Some(sid),
|
||||||
|
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||||
|
eof: Some(eof),
|
||||||
|
e: None,
|
||||||
|
code: None,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<String>) -> TunnelResponse {
|
async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<String>) -> TunnelResponse {
|
||||||
@@ -428,7 +583,7 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
|
|||||||
TunnelResponse {
|
TunnelResponse {
|
||||||
sid: Some(sid),
|
sid: Some(sid),
|
||||||
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
d: if data.is_empty() { None } else { Some(B64.encode(&data)) },
|
||||||
eof: Some(eof), e: None,
|
eof: Some(eof), e: None, code: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -441,7 +596,7 @@ async fn handle_close(state: &AppState, sid: Option<String>) -> TunnelResponse {
|
|||||||
s.reader_handle.abort();
|
s.reader_handle.abort();
|
||||||
tracing::info!("session {} closed by client", sid);
|
tracing::info!("session {} closed by client", sid);
|
||||||
}
|
}
|
||||||
TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None }
|
TunnelResponse { sid: Some(sid), d: None, eof: Some(true), e: None, code: None }
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
// ---------------------------------------------------------------------------
|
||||||
@@ -520,3 +675,129 @@ async fn main() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Tests
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use tokio::net::TcpListener;
|
||||||
|
|
||||||
|
fn fresh_state() -> AppState {
|
||||||
|
AppState {
|
||||||
|
sessions: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
auth_key: "test-key".into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spin up a one-shot TCP server that echoes everything it reads back
|
||||||
|
/// with a `"ECHO: "` prefix, then returns the bound port.
|
||||||
|
async fn start_echo_server() -> u16 {
|
||||||
|
let listener = TcpListener::bind(("127.0.0.1", 0)).await.unwrap();
|
||||||
|
let port = listener.local_addr().unwrap().port();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Ok((mut sock, _)) = listener.accept().await {
|
||||||
|
let mut buf = [0u8; 1024];
|
||||||
|
if let Ok(n) = sock.read(&mut buf).await {
|
||||||
|
let mut out = b"ECHO: ".to_vec();
|
||||||
|
out.extend_from_slice(&buf[..n]);
|
||||||
|
let _ = sock.write_all(&out).await;
|
||||||
|
let _ = sock.flush().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
port
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn unsupported_op_response_has_structured_code() {
|
||||||
|
let resp = TunnelResponse::unsupported_op("connect_data");
|
||||||
|
assert_eq!(resp.code.as_deref(), Some(CODE_UNSUPPORTED_OP));
|
||||||
|
assert_eq!(resp.e.as_deref(), Some("unknown op: connect_data"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn validate_host_port_rejects_empty_and_zero() {
|
||||||
|
assert!(validate_host_port(None, Some(443)).is_err());
|
||||||
|
assert!(validate_host_port(Some("".into()), Some(443)).is_err());
|
||||||
|
assert!(validate_host_port(Some("x".into()), None).is_err());
|
||||||
|
assert!(validate_host_port(Some("x".into()), Some(0)).is_err());
|
||||||
|
assert_eq!(
|
||||||
|
validate_host_port(Some("host".into()), Some(443)).unwrap(),
|
||||||
|
("host".to_string(), 443),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connect_data_phase1_writes_initial_data_and_returns_inner() {
|
||||||
|
let port = start_echo_server().await;
|
||||||
|
let state = fresh_state();
|
||||||
|
|
||||||
|
let (sid, inner) = handle_connect_data_phase1(
|
||||||
|
&state,
|
||||||
|
Some("127.0.0.1".into()),
|
||||||
|
Some(port),
|
||||||
|
Some(B64.encode(b"hello")),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("phase1 should succeed");
|
||||||
|
|
||||||
|
// Session was inserted.
|
||||||
|
assert!(state.sessions.lock().await.contains_key(&sid));
|
||||||
|
|
||||||
|
// Echo server sent back "ECHO: hello". Use wait_and_drain on the
|
||||||
|
// returned Arc — no map re-lookup needed (this is the fix).
|
||||||
|
let (data, _eof) = wait_and_drain(&inner, Duration::from_secs(2)).await;
|
||||||
|
assert_eq!(&data[..], b"ECHO: hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connect_data_single_bundles_connect_and_first_bytes() {
|
||||||
|
let port = start_echo_server().await;
|
||||||
|
let state = fresh_state();
|
||||||
|
|
||||||
|
let resp = handle_connect_data_single(
|
||||||
|
&state,
|
||||||
|
Some("127.0.0.1".into()),
|
||||||
|
Some(port),
|
||||||
|
Some(B64.encode(b"world")),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
assert!(resp.e.is_none(), "unexpected error: {:?}", resp.e);
|
||||||
|
assert!(resp.sid.is_some());
|
||||||
|
let decoded = B64.decode(resp.d.unwrap()).unwrap();
|
||||||
|
assert_eq!(&decoded[..], b"ECHO: world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connect_data_rejects_missing_host() {
|
||||||
|
let state = fresh_state();
|
||||||
|
let resp = handle_connect_data_single(
|
||||||
|
&state, None, Some(443), Some(B64.encode(b"x")),
|
||||||
|
).await;
|
||||||
|
assert!(resp.e.as_deref().unwrap_or("").contains("missing host"));
|
||||||
|
assert!(state.sessions.lock().await.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn connect_data_rejects_bad_base64_and_does_not_leak_session() {
|
||||||
|
// Need a live target so we reach the base64-decode step after
|
||||||
|
// create_session succeeds — otherwise we'd fail earlier.
|
||||||
|
let port = start_echo_server().await;
|
||||||
|
let state = fresh_state();
|
||||||
|
let resp = handle_connect_data_single(
|
||||||
|
&state,
|
||||||
|
Some("127.0.0.1".into()),
|
||||||
|
Some(port),
|
||||||
|
Some("!!!not base64!!!".into()),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
assert!(resp.e.as_deref().unwrap_or("").contains("bad base64"));
|
||||||
|
// Session should NOT be in the map since phase1 rejected it.
|
||||||
|
assert!(state.sessions.lock().await.is_empty());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user