fix(tunnel-node): batch drain correctness and lock contention (#695)

* fix(tunnel-node): batch drain correctness and lock contention

* fix(tunnel-node): single-op lock contention and batch base64 consistency
This commit is contained in:
dazzling-no-more
2026-05-04 04:33:49 +04:00
committed by GitHub
parent 3cb56c36c7
commit 38d9d9fcd6
+328 -116
View File
@@ -393,6 +393,27 @@ async fn drain_now(session: &SessionInner) -> (Vec<u8>, bool) {
/// wait for a real notify. Without this filter, an idle long-poll /// wait for a real notify. Without this filter, an idle long-poll
/// batch could return in <1 ms on a stale permit and degrade push /// batch could return in <1 ms on a stale permit and degrade push
/// delivery to the client's idle re-poll cadence. /// delivery to the client's idle re-poll cadence.
/// `JoinHandle` newtype that aborts the task on `Drop`. Lets the waiter
/// helpers below be cancel-safe under `tokio::select!`: a plain
/// `Vec<JoinHandle<()>>` only releases its handles via `Drop`, which
/// *detaches* tasks rather than aborting them. The previous shape
/// relied on a trailing `for w in &watchers { w.abort(); }` loop —
/// fine when the function ran to completion, but past the cancellation
/// points (`is_any_drainable().await`, the inner `select!`), so
/// cancelling the loser arm of the phase-2 `select!` left N orphan
/// watchers parked on `notify.notified()`. Each held an
/// `Arc<…Inner>` and could steal a `notify_one()` permit from a
/// future batch's watcher, making that batch wait until the next
/// notify or its deadline. Wrapping in `AbortOnDrop` makes cleanup
/// happen on every exit path, including cancellation.
struct AbortOnDrop(tokio::task::JoinHandle<()>);
impl Drop for AbortOnDrop {
fn drop(&mut self) {
self.0.abort();
}
}
async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration) { async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration) {
if inners.is_empty() { if inners.is_empty() {
return; return;
@@ -400,15 +421,15 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
// One watcher per session. Each loops until it observes real state // One watcher per session. Each loops until it observes real state
// (eof set or buffer non-empty) before signaling — see the // (eof set or buffer non-empty) before signaling — see the
// race-safety note on `wait_for_any_drainable` for why. We abort the // race-safety note above. Watchers are held in a Vec of
// watchers on return; the only state they hold is a notify // `AbortOnDrop`, so they're aborted on every exit path —
// subscription, so abort is clean. // including cancellation by an outer `select!`.
let (tx, mut rx) = mpsc::channel::<()>(1); let (tx, mut rx) = mpsc::channel::<()>(1);
let mut watchers = Vec::with_capacity(inners.len()); let mut _watchers: Vec<AbortOnDrop> = Vec::with_capacity(inners.len());
for inner in inners { for inner in inners {
let inner = inner.clone(); let inner = inner.clone();
let tx = tx.clone(); let tx = tx.clone();
watchers.push(tokio::spawn(async move { _watchers.push(AbortOnDrop(tokio::spawn(async move {
loop { loop {
inner.notify.notified().await; inner.notify.notified().await;
if inner.eof.load(Ordering::Acquire) { if inner.eof.load(Ordering::Acquire) {
@@ -423,7 +444,7 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
// notify, don't wake the caller. // notify, don't wake the caller.
} }
let _ = tx.try_send(()); let _ = tx.try_send(());
})); })));
} }
drop(tx); drop(tx);
@@ -441,9 +462,9 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
} }
} }
for w in &watchers { // No explicit abort loop: `_watchers`'s `AbortOnDrop` entries fire
w.abort(); // on the function returning here AND on the future being dropped
} // mid-await by an outer `select!`.
} }
/// True iff any session is currently drainable: its read buffer has /// True iff any session is currently drainable: its read buffer has
@@ -481,12 +502,14 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
return; return;
} }
// See `AbortOnDrop` and the comment on `wait_for_any_drainable`
// for why watchers must be aborted on every exit path.
let (tx, mut rx) = mpsc::channel::<()>(1); let (tx, mut rx) = mpsc::channel::<()>(1);
let mut watchers = Vec::with_capacity(inners.len()); let mut _watchers: Vec<AbortOnDrop> = Vec::with_capacity(inners.len());
for inner in inners { for inner in inners {
let inner = inner.clone(); let inner = inner.clone();
let tx = tx.clone(); let tx = tx.clone();
watchers.push(tokio::spawn(async move { _watchers.push(AbortOnDrop(tokio::spawn(async move {
loop { loop {
inner.notify.notified().await; inner.notify.notified().await;
if inner.eof.load(Ordering::Acquire) { if inner.eof.load(Ordering::Acquire) {
@@ -499,7 +522,7 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
// prior batch. Loop back, don't wake the caller. // prior batch. Loop back, don't wake the caller.
} }
let _ = tx.try_send(()); let _ = tx.try_send(());
})); })));
} }
drop(tx); drop(tx);
@@ -510,10 +533,6 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
_ = tokio::time::sleep(deadline) => {} _ = tokio::time::sleep(deadline) => {}
} }
} }
for w in &watchers {
w.abort();
}
} }
async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool { async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool {
@@ -565,7 +584,10 @@ async fn wait_and_drain(session: &SessionInner, max_wait: Duration) -> (Vec<u8>,
struct AppState { struct AppState {
sessions: Arc<Mutex<HashMap<String, ManagedSession>>>, sessions: Arc<Mutex<HashMap<String, ManagedSession>>>,
udp_sessions: Arc<Mutex<HashMap<String, ManagedUdpSession>>>, udp_sessions: Arc<Mutex<HashMap<String, ManagedUdpSession>>>,
auth_key: String, /// Shared, immutable after startup. `Arc<str>` so each `state.clone()`
/// — once per phase-1 spawn in the batch handler — is a refcount bump
/// instead of a fresh String allocation.
auth_key: Arc<str>,
/// Active probing defense: when false (default, production), bad /// Active probing defense: when false (default, production), bad
/// AUTH_KEY responses are a generic-looking 404 with no JSON-shaped /// AUTH_KEY responses are a generic-looking 404 with no JSON-shaped
/// "unauthorized" body — same as a static nginx 404. Active scanners /// "unauthorized" body — same as a static nginx 404. Active scanners
@@ -650,7 +672,7 @@ async fn handle_tunnel(
State(state): State<AppState>, State(state): State<AppState>,
Json(req): Json<TunnelRequest>, Json(req): Json<TunnelRequest>,
) -> axum::response::Response { ) -> axum::response::Response {
if req.k != state.auth_key { if req.k != *state.auth_key {
return decoy_or_unauthorized(state.diagnostic_mode); return decoy_or_unauthorized(state.diagnostic_mode);
} }
let resp: TunnelResponse = match req.op.as_str() { let resp: TunnelResponse = match req.op.as_str() {
@@ -719,7 +741,7 @@ async fn handle_batch(
} }
}; };
if req.k != state.auth_key { if req.k != *state.auth_key {
if state.diagnostic_mode { if state.diagnostic_mode {
let resp = serde_json::to_vec(&BatchResponse { let resp = serde_json::to_vec(&BatchResponse {
r: vec![TunnelResponse::error("unauthorized")], r: vec![TunnelResponse::error("unauthorized")],
@@ -752,8 +774,13 @@ async fn handle_batch(
// still fires from server-speaks-first ports and from the preread // still fires from server-speaks-first ports and from the preread
// timeout fallback path. // timeout fallback path.
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len()); let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
let mut tcp_drains: Vec<(usize, String)> = Vec::new(); // Each drain entry carries the session's `Arc<…Inner>` alongside the
let mut udp_drains: Vec<(usize, String)> = Vec::new(); // sid. Phase 2 drains through the Arc directly so the global sessions
// map lock isn't held across the per-session read_buf / packets
// mutex acquisition — without this, every other batch (and every
// connect/close op) head-of-line-blocks behind the drain.
let mut tcp_drains: Vec<(usize, String, Arc<SessionInner>)> = Vec::new();
let mut udp_drains: Vec<(usize, String, Arc<UdpSessionInner>)> = Vec::new();
// True iff the batch contained any op that performed a real action // True iff the batch contained any op that performed a real action
// upstream — a new connection or a non-empty data write. A batch of // upstream — a new connection or a non-empty data write. A batch of
// only empty "data" / "udp_data" polls (and possibly closes) leaves // only empty "data" / "udp_data" polls (and possibly closes) leaves
@@ -762,8 +789,8 @@ async fn handle_batch(
enum NewConn { enum NewConn {
Connect(TunnelResponse), Connect(TunnelResponse),
ConnectData(Result<String, TunnelResponse>), ConnectData(Result<(String, Arc<SessionInner>), TunnelResponse>),
UdpOpen(Result<String, TunnelResponse>), UdpOpen(Result<(String, Arc<UdpSessionInner>), TunnelResponse>),
} }
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new(); let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
@@ -785,13 +812,11 @@ async fn handle_batch(
let port = op.port; let port = op.port;
let d = op.d.clone(); let d = op.d.clone();
new_conn_jobs.spawn(async move { new_conn_jobs.spawn(async move {
// Drop the returned Arc<SessionInner>: phase 2 below // Keep the returned Arc<SessionInner>: phase 2 drains
// re-looks up each sid under one sessions-map lock, // through it directly, so the global sessions map
// which is cheap. The Arc return is a convenience for // lock doesn't have to be held across the per-session
// the single-op path only. // read_buf.lock().await.
let r = handle_connect_data_phase1(&state, host, port, d) let r = handle_connect_data_phase1(&state, host, port, d).await;
.await
.map(|(sid, _inner)| sid);
(i, NewConn::ConnectData(r)) (i, NewConn::ConnectData(r))
}); });
} }
@@ -808,9 +833,7 @@ async fn handle_batch(
let port = op.port; let port = op.port;
let d = op.d.clone(); let d = op.d.clone();
new_conn_jobs.spawn(async move { new_conn_jobs.spawn(async move {
let r = handle_udp_open_phase1(&state, host, port, d) let r = handle_udp_open_phase1(&state, host, port, d).await;
.await
.map(|(sid, _inner)| sid);
(i, NewConn::UdpOpen(r)) (i, NewConn::UdpOpen(r))
}); });
} }
@@ -820,26 +843,46 @@ async fn handle_batch(
_ => { results.push((i, TunnelResponse::error("missing sid"))); continue; } _ => { results.push((i, TunnelResponse::error("missing sid"))); continue; }
}; };
// Write outbound data // Clone the inner under the map lock and release it
// before any await. The previous shape held the global
// sessions map across last_active.lock(), writer.lock(),
// write_all, and flush — head-of-line-blocking every
// other batch and connect/close op for the duration of
// a single upstream write. The udp_data branch below
// already does the right thing; this matches it.
let inner = {
let sessions = state.sessions.lock().await; let sessions = state.sessions.lock().await;
if let Some(session) = sessions.get(&sid) { sessions.get(&sid).map(|s| s.inner.clone())
*session.inner.last_active.lock().await = Instant::now(); };
if let Some(inner) = inner {
*inner.last_active.lock().await = Instant::now();
if let Some(ref data_b64) = op.d { if let Some(ref data_b64) = op.d {
if !data_b64.is_empty() { if !data_b64.is_empty() {
had_writes_or_connects = true; // Decode first; only count this op as a real
if let Ok(bytes) = B64.decode(data_b64) { // write (and demote the batch out of long-poll)
// after a successful non-empty decode. Mirrors
// the udp_data branch and avoids silently
// dropping bytes on bad base64.
let bytes = match B64.decode(data_b64) {
Ok(b) => b,
Err(e) => {
results.push((
i,
TunnelResponse::error(format!("bad base64: {}", e)),
));
continue;
}
};
if !bytes.is_empty() { if !bytes.is_empty() {
let mut w = session.inner.writer.lock().await; had_writes_or_connects = true;
let mut w = inner.writer.lock().await;
let _ = w.write_all(&bytes).await; let _ = w.write_all(&bytes).await;
let _ = w.flush().await; let _ = w.flush().await;
} }
} }
} }
} tcp_drains.push((i, sid, inner));
drop(sessions);
tcp_drains.push((i, sid));
} else { } else {
drop(sessions);
results.push((i, eof_response(sid))); results.push((i, eof_response(sid)));
} }
} }
@@ -881,7 +924,7 @@ async fn handle_batch(
if had_uplink { if had_uplink {
*inner.last_active.lock().await = Instant::now(); *inner.last_active.lock().await = Instant::now();
} }
udp_drains.push((i, sid)); udp_drains.push((i, sid, inner));
} else { } else {
results.push((i, eof_response(sid))); results.push((i, eof_response(sid)));
} }
@@ -902,9 +945,13 @@ async fn handle_batch(
while let Some(join) = new_conn_jobs.join_next().await { while let Some(join) = new_conn_jobs.join_next().await {
match join { match join {
Ok((i, NewConn::Connect(r))) => results.push((i, r)), Ok((i, NewConn::Connect(r))) => results.push((i, r)),
Ok((i, NewConn::ConnectData(Ok(sid)))) => tcp_drains.push((i, sid)), Ok((i, NewConn::ConnectData(Ok((sid, inner))))) => {
tcp_drains.push((i, sid, inner));
}
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)), Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
Ok((i, NewConn::UdpOpen(Ok(sid)))) => udp_drains.push((i, sid)), Ok((i, NewConn::UdpOpen(Ok((sid, inner))))) => {
udp_drains.push((i, sid, inner));
}
Ok((i, NewConn::UdpOpen(Err(r)))) => results.push((i, r)), Ok((i, NewConn::UdpOpen(Err(r)))) => results.push((i, r)),
Err(e) => { Err(e) => {
tracing::error!("new-connection task panicked: {}", e); tracing::error!("new-connection task panicked: {}", e);
@@ -930,34 +977,38 @@ async fn handle_batch(
LONGPOLL_DEADLINE LONGPOLL_DEADLINE
}; };
let tcp_inners: Vec<Arc<SessionInner>> = { // Phase 1 already gave us each session's Arc<…Inner>, so we
let sessions = state.sessions.lock().await; // don't need to re-acquire the sessions map lock here. Cloning
tcp_drains // the Arc is just a refcount bump.
.iter() let tcp_inners: Vec<Arc<SessionInner>> =
.filter_map(|(_, sid)| sessions.get(sid).map(|s| s.inner.clone())) tcp_drains.iter().map(|(_, _, inner)| inner.clone()).collect();
.collect() let udp_inners: Vec<Arc<UdpSessionInner>> =
}; udp_drains.iter().map(|(_, _, inner)| inner.clone()).collect();
let udp_inners: Vec<Arc<UdpSessionInner>> = {
let sessions = state.udp_sessions.lock().await;
udp_drains
.iter()
.filter_map(|(_, sid)| sessions.get(sid).map(|s| s.inner.clone()))
.collect()
};
// Wait for either side to wake. Running both concurrently means // Wake on whichever side has work first. The previous
// a TCP-only batch isn't slowed by a stale UDP watch list, and // `tokio::join!` was conjunctive — a TCP burst still paid the
// vice versa. // UDP deadline in mixed batches because the UDP waiter had to
tokio::join!( // elapse too. `wait_for_*_drainable` short-circuits on an empty
wait_for_any_drainable(&tcp_inners, deadline), // slice, so we have to skip the empty side; otherwise its
wait_for_any_udp_drainable(&udp_inners, deadline), // instant return would fire the select arm before the other
); // side ever got a chance to wait.
match (tcp_inners.is_empty(), udp_inners.is_empty()) {
(true, true) => {}
(false, true) => wait_for_any_drainable(&tcp_inners, deadline).await,
(true, false) => wait_for_any_udp_drainable(&udp_inners, deadline).await,
(false, false) => {
tokio::select! {
_ = wait_for_any_drainable(&tcp_inners, deadline) => {}
_ = wait_for_any_udp_drainable(&udp_inners, deadline) => {}
}
}
}
if had_writes_or_connects { if had_writes_or_connects {
// Adaptive settle: keep waiting in steps while new data // Adaptive settle: keep waiting in steps while new data
// keeps arriving. Break when: // keeps arriving. Break when:
// 1. No new data arrived in the last step (burst is over) // 1. No new data arrived in the last step (burst is over)
// 2. 500ms max reached // 2. STRAGGLER_SETTLE_MAX reached
let settle_end = Instant::now() + STRAGGLER_SETTLE_MAX; let settle_end = Instant::now() + STRAGGLER_SETTLE_MAX;
let mut prev_tcp_bytes: usize = 0; let mut prev_tcp_bytes: usize = 0;
let mut prev_udp_pkts: usize = 0; let mut prev_udp_pkts: usize = 0;
@@ -997,53 +1048,56 @@ async fn handle_batch(
} }
// ---- TCP drain ---- // ---- TCP drain ----
if !tcp_drains.is_empty() { // Drain through each session's already-cloned Arc so the global
let sessions = state.sessions.lock().await; // sessions map lock isn't held across the per-session
for (i, sid) in &tcp_drains { // read_buf.lock().await.
if let Some(session) = sessions.get(sid) { //
let (data, eof) = drain_now(&session.inner).await; // Cleanup is driven off `drain_now`'s returned `eof`, NOT the
// raw `inner.eof` atomic. When the buffer exceeds
// `TCP_DRAIN_MAX_BYTES`, `drain_now` deliberately returns
// `eof = false` and leaves the tail in the buffer so the
// client can pick it up on the next poll. The previous cleanup
// read the atomic directly, so on a high-throughput session
// that closed mid-burst (issue #460-style) it would remove the
// session and abort the reader_task with the tail still
// buffered, dropping those bytes.
let mut tcp_eof_sids: Vec<String> = Vec::new();
for (i, sid, inner) in &tcp_drains {
let (data, eof) = drain_now(inner).await;
if eof {
tcp_eof_sids.push(sid.clone());
}
results.push((*i, tcp_drain_response(sid.clone(), data, eof))); results.push((*i, tcp_drain_response(sid.clone(), data, eof)));
} else {
results.push((*i, eof_response(sid.clone())));
} }
} if !tcp_eof_sids.is_empty() {
drop(sessions);
// Clean up eof TCP sessions.
let mut sessions = state.sessions.lock().await; let mut sessions = state.sessions.lock().await;
for (_, sid) in &tcp_drains { for sid in &tcp_eof_sids {
if let Some(s) = sessions.get(sid) {
if s.inner.eof.load(Ordering::Acquire) {
if let Some(s) = sessions.remove(sid) { if let Some(s) = sessions.remove(sid) {
s.reader_handle.abort(); s.reader_handle.abort();
tracing::info!("session {} closed by remote (batch)", sid); tracing::info!("session {} closed by remote (batch)", sid);
} }
} }
} }
}
}
// ---- UDP drain ---- // ---- UDP drain ----
if !udp_drains.is_empty() { // Same shape as TCP. `drain_udp_now` currently drains the full
{ // queue with no per-batch cap, so its returned `eof` already
let sessions = state.udp_sessions.lock().await; // matches the atomic — driving cleanup off the drain return
for (i, sid) in &udp_drains { // is future-proofing: if a UDP per-batch packet cap is ever
if let Some(session) = sessions.get(sid) { // added (mirroring `TCP_DRAIN_MAX_BYTES`), the same data-loss
let (packets, eof) = drain_udp_now(&session.inner).await; // trap that motivated the TCP-side fix reappears, and tracking
// eof from the drain return rather than the atomic catches it.
let mut udp_eof_sids: Vec<String> = Vec::new();
for (i, sid, inner) in &udp_drains {
let (packets, eof) = drain_udp_now(inner).await;
if eof {
udp_eof_sids.push(sid.clone());
}
results.push((*i, udp_drain_response(sid.clone(), packets, eof))); results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
} else {
results.push((*i, eof_response(sid.clone())));
} }
} if !udp_eof_sids.is_empty() {
}
// Clean up eof UDP sessions so a future batch with the same
// sid gets the "session not found" eof immediately rather
// than re-checking the (already-stale) eof flag.
let mut sessions = state.udp_sessions.lock().await; let mut sessions = state.udp_sessions.lock().await;
for (_, sid) in &udp_drains { for sid in &udp_eof_sids {
if let Some(s) = sessions.get(sid) {
if s.inner.eof.load(Ordering::Acquire) {
if let Some(s) = sessions.remove(sid) { if let Some(s) = sessions.remove(sid) {
s.reader_handle.abort(); s.reader_handle.abort();
tracing::info!("udp session {} closed by remote (batch)", sid); tracing::info!("udp session {} closed by remote (batch)", sid);
@@ -1051,8 +1105,6 @@ async fn handle_batch(
} }
} }
} }
}
}
// Sort results by original index and build response // Sort results by original index and build response
results.sort_by_key(|(i, _)| *i); results.sort_by_key(|(i, _)| *i);
@@ -1149,11 +1201,12 @@ async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16
} }
/// Open a session and write the client's first bytes in one round trip. /// Open a session and write the client's first bytes in one round trip.
/// Returns the new sid plus an `Arc<SessionInner>` so unary callers /// Returns the new sid plus an `Arc<SessionInner>`. Both callers keep
/// (`handle_connect_data_single`) can drain the first response without a /// the Arc: the unary path (`handle_connect_data_single`) uses it to
/// second sessions-map lookup. The batch caller drops the Arc — it takes /// drain the first response without a second sessions-map lookup, and
/// a single lock across all drain-bound sessions in phase 2, which is /// the batch path threads it into `tcp_drains` so phase-2 drain runs
/// cheaper than the Arc plumbing would be. /// without holding the global sessions map lock across the per-session
/// `read_buf.lock().await`.
async fn handle_connect_data_phase1( async fn handle_connect_data_phase1(
state: &AppState, state: &AppState,
host: Option<String>, host: Option<String>,
@@ -1274,19 +1327,27 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
Some(s) if !s.is_empty() => s, Some(s) if !s.is_empty() => s,
_ => return TunnelResponse::error("missing sid"), _ => return TunnelResponse::error("missing sid"),
}; };
// Clone the inner Arc under the global sessions map lock and release
// the map lock before any await. The previous shape held the map
// across last_active.lock(), writer.lock(), write_all, flush, AND
// wait_and_drain — up to 5 s of head-of-line blocking on every other
// single-op or batch request. Mirrors the batch-handler "data" path.
let inner = {
let sessions = state.sessions.lock().await; let sessions = state.sessions.lock().await;
let session = match sessions.get(&sid) { sessions.get(&sid).map(|s| s.inner.clone())
Some(s) => s, };
let inner = match inner {
Some(i) => i,
None => return TunnelResponse::error("unknown session"), None => return TunnelResponse::error("unknown session"),
}; };
*session.inner.last_active.lock().await = Instant::now(); *inner.last_active.lock().await = Instant::now();
if let Some(ref data_b64) = data { if let Some(ref data_b64) = data {
if !data_b64.is_empty() { if !data_b64.is_empty() {
if let Ok(bytes) = B64.decode(data_b64) { if let Ok(bytes) = B64.decode(data_b64) {
if !bytes.is_empty() { if !bytes.is_empty() {
let mut w = session.inner.writer.lock().await; let mut w = inner.writer.lock().await;
if let Err(e) = w.write_all(&bytes).await { if let Err(e) = w.write_all(&bytes).await {
drop(w); drop(sessions); drop(w);
state.sessions.lock().await.remove(&sid); state.sessions.lock().await.remove(&sid);
return TunnelResponse::error(format!("write failed: {}", e)); return TunnelResponse::error(format!("write failed: {}", e));
} }
@@ -1295,8 +1356,7 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
} }
} }
} }
let (data, eof) = wait_and_drain(&session.inner, Duration::from_secs(5)).await; let (data, eof) = wait_and_drain(&inner, Duration::from_secs(5)).await;
drop(sessions);
if eof { if eof {
if let Some(s) = state.sessions.lock().await.remove(&sid) { if let Some(s) = state.sessions.lock().await.remove(&sid) {
s.reader_handle.abort(); s.reader_handle.abort();
@@ -1449,7 +1509,12 @@ async fn main() {
before exposing this tunnel-node to the public internet." before exposing this tunnel-node to the public internet."
); );
} }
let state = AppState { sessions, udp_sessions, auth_key, diagnostic_mode }; let state = AppState {
sessions,
udp_sessions,
auth_key: Arc::from(auth_key),
diagnostic_mode,
};
let app = Router::new() let app = Router::new()
.route("/tunnel", post(handle_tunnel)) .route("/tunnel", post(handle_tunnel))
@@ -2249,4 +2314,151 @@ mod tests {
assert_eq!(r.len(), 1); assert_eq!(r.len(), 1);
assert_eq!(r[0]["eof"], serde_json::Value::Bool(true)); assert_eq!(r[0]["eof"], serde_json::Value::Bool(true));
} }
/// Regression for the cleanup-correctness fix. Previously, the
/// batch handler reaped any session whose `inner.eof` atomic was
/// set, even when `drain_now` had withheld eof to keep tail bytes
/// buffered (i.e. the buffer exceeded `TCP_DRAIN_MAX_BYTES`).
/// Reaping aborted the reader_task and dropped the tail. Cleanup
/// is now driven off the drain's returned `eof`, so an over-cap
/// buffer + atomic eof keeps the session alive through the first
/// poll and only reaps on the drain that actually returns eof.
#[tokio::test]
async fn batch_keeps_over_cap_session_until_tail_is_drained() {
use axum::body::Bytes;
use axum::extract::State;
let state = fresh_state();
let inner = fake_inner().await;
// Prime an over-cap buffer + raw eof. drain_now will return
// TCP_DRAIN_MAX_BYTES bytes with eof=false; the previous
// cleanup would still reap because it read inner.eof directly.
inner
.read_buf
.lock()
.await
.resize(TCP_DRAIN_MAX_BYTES + 4096, 0u8);
inner.eof.store(true, Ordering::Release);
let sid = "over-cap-sid".to_string();
state.sessions.lock().await.insert(
sid.clone(),
ManagedSession {
inner: inner.clone(),
reader_handle: tokio::spawn(async {}),
udpgw_handle: None,
},
);
let body = serde_json::json!({
"k": "test-key",
"ops": [{"op": "data", "sid": &sid}]
})
.to_string();
let _resp = handle_batch(State(state.clone()), Bytes::from(body))
.await
.into_response();
// First poll: session must still be in the map, tail intact.
// The previous code reaped here and dropped the 4096 tail bytes.
{
let sessions = state.sessions.lock().await;
let s = sessions.get(&sid).expect(
"session removed despite tail bytes still buffered; \
drain_now returned eof=false but cleanup ignored that \
and read inner.eof directly",
);
let remaining = s.inner.read_buf.lock().await.len();
assert_eq!(remaining, 4096, "tail must be preserved for next drain");
}
// Second poll: drain_now sees buf.len() ≤ cap AND raw_eof,
// so returns eof=true. Cleanup runs and the session is reaped.
let body2 = serde_json::json!({
"k": "test-key",
"ops": [{"op": "data", "sid": &sid}]
})
.to_string();
let _resp2 = handle_batch(State(state.clone()), Bytes::from(body2))
.await
.into_response();
assert!(
!state.sessions.lock().await.contains_key(&sid),
"session should be reaped on the drain that returns eof=true",
);
}
/// Regression for the `tokio::join!` → `tokio::select!` mixed-drain
/// fix. Before the change, a TCP-ready / UDP-idle pure-poll batch
/// paid the full UDP `LONGPOLL_DEADLINE` (15 s) because the join
/// was conjunctive — both arms had to complete. Under select! the
/// TCP wake returns the response promptly even though UDP is
/// quiet. The bound is loose (1 s) on purpose: real elapsed is
/// in the millisecond range, but the prior bug would have
/// triggered the test timeout instead of the assert.
#[tokio::test]
async fn batch_tcp_ready_does_not_pay_udp_longpoll_deadline() {
use axum::body::Bytes;
use axum::extract::State;
let state = fresh_state();
// TCP session with bytes already buffered → immediately drainable.
let tcp_inner = fake_inner().await;
tcp_inner
.read_buf
.lock()
.await
.extend_from_slice(b"ready");
let tcp_sid = "tcp-sid".to_string();
state.sessions.lock().await.insert(
tcp_sid.clone(),
ManagedSession {
inner: tcp_inner,
reader_handle: tokio::spawn(async {}),
udpgw_handle: None,
},
);
// Idle UDP session — never wakes. Real upstream so udp_open
// succeeds; we just never send anything to it.
let udp_target = UdpSocket::bind(("127.0.0.1", 0)).await.unwrap();
let udp_port = udp_target.local_addr().unwrap().port();
let (udp_sid, _udp_inner) = handle_udp_open_phase1(
&state,
Some("127.0.0.1".into()),
Some(udp_port),
None,
)
.await
.expect("udp open");
// Pure-poll batch (no `d` payload) → had_writes_or_connects =
// false → deadline = LONGPOLL_DEADLINE (15 s). Under the
// previous tokio::join! wait, the UDP arm would have held the
// response open for the full window even though TCP was
// already drainable.
let body = serde_json::json!({
"k": "test-key",
"ops": [
{"op": "data", "sid": &tcp_sid},
{"op": "udp_data", "sid": &udp_sid},
]
})
.to_string();
let t0 = Instant::now();
let _resp = handle_batch(State(state.clone()), Bytes::from(body))
.await
.into_response();
let elapsed = t0.elapsed();
assert!(
elapsed < Duration::from_secs(1),
"TCP-ready / UDP-idle pure-poll batch must not pay \
LONGPOLL_DEADLINE; elapsed={:?}",
elapsed,
);
}
} }