mirror of
https://github.com/therealaleph/MasterHttpRelayVPN-RUST.git
synced 2026-05-17 21:24:48 +03:00
fix(tunnel-node): batch drain correctness and lock contention (#695)
* fix(tunnel-node): batch drain correctness and lock contention * fix(tunnel-node): single-op lock contention and batch base64 consistency
This commit is contained in:
+340
-128
@@ -393,6 +393,27 @@ async fn drain_now(session: &SessionInner) -> (Vec<u8>, bool) {
|
||||
/// wait for a real notify. Without this filter, an idle long-poll
|
||||
/// batch could return in <1 ms on a stale permit and degrade push
|
||||
/// delivery to the client's idle re-poll cadence.
|
||||
/// `JoinHandle` newtype that aborts the task on `Drop`. Lets the waiter
|
||||
/// helpers below be cancel-safe under `tokio::select!`: a plain
|
||||
/// `Vec<JoinHandle<()>>` only releases its handles via `Drop`, which
|
||||
/// *detaches* tasks rather than aborting them. The previous shape
|
||||
/// relied on a trailing `for w in &watchers { w.abort(); }` loop —
|
||||
/// fine when the function ran to completion, but past the cancellation
|
||||
/// points (`is_any_drainable().await`, the inner `select!`), so
|
||||
/// cancelling the loser arm of the phase-2 `select!` left N orphan
|
||||
/// watchers parked on `notify.notified()`. Each held an
|
||||
/// `Arc<…Inner>` and could steal a `notify_one()` permit from a
|
||||
/// future batch's watcher, making that batch wait until the next
|
||||
/// notify or its deadline. Wrapping in `AbortOnDrop` makes cleanup
|
||||
/// happen on every exit path, including cancellation.
|
||||
struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||
|
||||
impl Drop for AbortOnDrop {
|
||||
fn drop(&mut self) {
|
||||
self.0.abort();
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration) {
|
||||
if inners.is_empty() {
|
||||
return;
|
||||
@@ -400,15 +421,15 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
|
||||
|
||||
// One watcher per session. Each loops until it observes real state
|
||||
// (eof set or buffer non-empty) before signaling — see the
|
||||
// race-safety note on `wait_for_any_drainable` for why. We abort the
|
||||
// watchers on return; the only state they hold is a notify
|
||||
// subscription, so abort is clean.
|
||||
// race-safety note above. Watchers are held in a Vec of
|
||||
// `AbortOnDrop`, so they're aborted on every exit path —
|
||||
// including cancellation by an outer `select!`.
|
||||
let (tx, mut rx) = mpsc::channel::<()>(1);
|
||||
let mut watchers = Vec::with_capacity(inners.len());
|
||||
let mut _watchers: Vec<AbortOnDrop> = Vec::with_capacity(inners.len());
|
||||
for inner in inners {
|
||||
let inner = inner.clone();
|
||||
let tx = tx.clone();
|
||||
watchers.push(tokio::spawn(async move {
|
||||
_watchers.push(AbortOnDrop(tokio::spawn(async move {
|
||||
loop {
|
||||
inner.notify.notified().await;
|
||||
if inner.eof.load(Ordering::Acquire) {
|
||||
@@ -423,7 +444,7 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
|
||||
// notify, don't wake the caller.
|
||||
}
|
||||
let _ = tx.try_send(());
|
||||
}));
|
||||
})));
|
||||
}
|
||||
drop(tx);
|
||||
|
||||
@@ -441,9 +462,9 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
|
||||
}
|
||||
}
|
||||
|
||||
for w in &watchers {
|
||||
w.abort();
|
||||
}
|
||||
// No explicit abort loop: `_watchers`'s `AbortOnDrop` entries fire
|
||||
// on the function returning here AND on the future being dropped
|
||||
// mid-await by an outer `select!`.
|
||||
}
|
||||
|
||||
/// True iff any session is currently drainable: its read buffer has
|
||||
@@ -481,12 +502,14 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
|
||||
return;
|
||||
}
|
||||
|
||||
// See `AbortOnDrop` and the comment on `wait_for_any_drainable`
|
||||
// for why watchers must be aborted on every exit path.
|
||||
let (tx, mut rx) = mpsc::channel::<()>(1);
|
||||
let mut watchers = Vec::with_capacity(inners.len());
|
||||
let mut _watchers: Vec<AbortOnDrop> = Vec::with_capacity(inners.len());
|
||||
for inner in inners {
|
||||
let inner = inner.clone();
|
||||
let tx = tx.clone();
|
||||
watchers.push(tokio::spawn(async move {
|
||||
_watchers.push(AbortOnDrop(tokio::spawn(async move {
|
||||
loop {
|
||||
inner.notify.notified().await;
|
||||
if inner.eof.load(Ordering::Acquire) {
|
||||
@@ -499,7 +522,7 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
|
||||
// prior batch. Loop back, don't wake the caller.
|
||||
}
|
||||
let _ = tx.try_send(());
|
||||
}));
|
||||
})));
|
||||
}
|
||||
drop(tx);
|
||||
|
||||
@@ -510,10 +533,6 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
|
||||
_ = tokio::time::sleep(deadline) => {}
|
||||
}
|
||||
}
|
||||
|
||||
for w in &watchers {
|
||||
w.abort();
|
||||
}
|
||||
}
|
||||
|
||||
async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool {
|
||||
@@ -565,7 +584,10 @@ async fn wait_and_drain(session: &SessionInner, max_wait: Duration) -> (Vec<u8>,
|
||||
struct AppState {
|
||||
sessions: Arc<Mutex<HashMap<String, ManagedSession>>>,
|
||||
udp_sessions: Arc<Mutex<HashMap<String, ManagedUdpSession>>>,
|
||||
auth_key: String,
|
||||
/// Shared, immutable after startup. `Arc<str>` so each `state.clone()`
|
||||
/// — once per phase-1 spawn in the batch handler — is a refcount bump
|
||||
/// instead of a fresh String allocation.
|
||||
auth_key: Arc<str>,
|
||||
/// Active probing defense: when false (default, production), bad
|
||||
/// AUTH_KEY responses are a generic-looking 404 with no JSON-shaped
|
||||
/// "unauthorized" body — same as a static nginx 404. Active scanners
|
||||
@@ -650,7 +672,7 @@ async fn handle_tunnel(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<TunnelRequest>,
|
||||
) -> axum::response::Response {
|
||||
if req.k != state.auth_key {
|
||||
if req.k != *state.auth_key {
|
||||
return decoy_or_unauthorized(state.diagnostic_mode);
|
||||
}
|
||||
let resp: TunnelResponse = match req.op.as_str() {
|
||||
@@ -719,7 +741,7 @@ async fn handle_batch(
|
||||
}
|
||||
};
|
||||
|
||||
if req.k != state.auth_key {
|
||||
if req.k != *state.auth_key {
|
||||
if state.diagnostic_mode {
|
||||
let resp = serde_json::to_vec(&BatchResponse {
|
||||
r: vec![TunnelResponse::error("unauthorized")],
|
||||
@@ -752,8 +774,13 @@ async fn handle_batch(
|
||||
// still fires from server-speaks-first ports and from the preread
|
||||
// timeout fallback path.
|
||||
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
|
||||
let mut tcp_drains: Vec<(usize, String)> = Vec::new();
|
||||
let mut udp_drains: Vec<(usize, String)> = Vec::new();
|
||||
// Each drain entry carries the session's `Arc<…Inner>` alongside the
|
||||
// sid. Phase 2 drains through the Arc directly so the global sessions
|
||||
// map lock isn't held across the per-session read_buf / packets
|
||||
// mutex acquisition — without this, every other batch (and every
|
||||
// connect/close op) head-of-line-blocks behind the drain.
|
||||
let mut tcp_drains: Vec<(usize, String, Arc<SessionInner>)> = Vec::new();
|
||||
let mut udp_drains: Vec<(usize, String, Arc<UdpSessionInner>)> = Vec::new();
|
||||
// True iff the batch contained any op that performed a real action
|
||||
// upstream — a new connection or a non-empty data write. A batch of
|
||||
// only empty "data" / "udp_data" polls (and possibly closes) leaves
|
||||
@@ -762,8 +789,8 @@ async fn handle_batch(
|
||||
|
||||
enum NewConn {
|
||||
Connect(TunnelResponse),
|
||||
ConnectData(Result<String, TunnelResponse>),
|
||||
UdpOpen(Result<String, TunnelResponse>),
|
||||
ConnectData(Result<(String, Arc<SessionInner>), TunnelResponse>),
|
||||
UdpOpen(Result<(String, Arc<UdpSessionInner>), TunnelResponse>),
|
||||
}
|
||||
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
|
||||
|
||||
@@ -785,13 +812,11 @@ async fn handle_batch(
|
||||
let port = op.port;
|
||||
let d = op.d.clone();
|
||||
new_conn_jobs.spawn(async move {
|
||||
// Drop the returned Arc<SessionInner>: phase 2 below
|
||||
// re-looks up each sid under one sessions-map lock,
|
||||
// which is cheap. The Arc return is a convenience for
|
||||
// the single-op path only.
|
||||
let r = handle_connect_data_phase1(&state, host, port, d)
|
||||
.await
|
||||
.map(|(sid, _inner)| sid);
|
||||
// Keep the returned Arc<SessionInner>: phase 2 drains
|
||||
// through it directly, so the global sessions map
|
||||
// lock doesn't have to be held across the per-session
|
||||
// read_buf.lock().await.
|
||||
let r = handle_connect_data_phase1(&state, host, port, d).await;
|
||||
(i, NewConn::ConnectData(r))
|
||||
});
|
||||
}
|
||||
@@ -808,9 +833,7 @@ async fn handle_batch(
|
||||
let port = op.port;
|
||||
let d = op.d.clone();
|
||||
new_conn_jobs.spawn(async move {
|
||||
let r = handle_udp_open_phase1(&state, host, port, d)
|
||||
.await
|
||||
.map(|(sid, _inner)| sid);
|
||||
let r = handle_udp_open_phase1(&state, host, port, d).await;
|
||||
(i, NewConn::UdpOpen(r))
|
||||
});
|
||||
}
|
||||
@@ -820,26 +843,46 @@ async fn handle_batch(
|
||||
_ => { results.push((i, TunnelResponse::error("missing sid"))); continue; }
|
||||
};
|
||||
|
||||
// Write outbound data
|
||||
let sessions = state.sessions.lock().await;
|
||||
if let Some(session) = sessions.get(&sid) {
|
||||
*session.inner.last_active.lock().await = Instant::now();
|
||||
// Clone the inner under the map lock and release it
|
||||
// before any await. The previous shape held the global
|
||||
// sessions map across last_active.lock(), writer.lock(),
|
||||
// write_all, and flush — head-of-line-blocking every
|
||||
// other batch and connect/close op for the duration of
|
||||
// a single upstream write. The udp_data branch below
|
||||
// already does the right thing; this matches it.
|
||||
let inner = {
|
||||
let sessions = state.sessions.lock().await;
|
||||
sessions.get(&sid).map(|s| s.inner.clone())
|
||||
};
|
||||
if let Some(inner) = inner {
|
||||
*inner.last_active.lock().await = Instant::now();
|
||||
if let Some(ref data_b64) = op.d {
|
||||
if !data_b64.is_empty() {
|
||||
had_writes_or_connects = true;
|
||||
if let Ok(bytes) = B64.decode(data_b64) {
|
||||
if !bytes.is_empty() {
|
||||
let mut w = session.inner.writer.lock().await;
|
||||
let _ = w.write_all(&bytes).await;
|
||||
let _ = w.flush().await;
|
||||
// Decode first; only count this op as a real
|
||||
// write (and demote the batch out of long-poll)
|
||||
// after a successful non-empty decode. Mirrors
|
||||
// the udp_data branch and avoids silently
|
||||
// dropping bytes on bad base64.
|
||||
let bytes = match B64.decode(data_b64) {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
results.push((
|
||||
i,
|
||||
TunnelResponse::error(format!("bad base64: {}", e)),
|
||||
));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
if !bytes.is_empty() {
|
||||
had_writes_or_connects = true;
|
||||
let mut w = inner.writer.lock().await;
|
||||
let _ = w.write_all(&bytes).await;
|
||||
let _ = w.flush().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
drop(sessions);
|
||||
tcp_drains.push((i, sid));
|
||||
tcp_drains.push((i, sid, inner));
|
||||
} else {
|
||||
drop(sessions);
|
||||
results.push((i, eof_response(sid)));
|
||||
}
|
||||
}
|
||||
@@ -881,7 +924,7 @@ async fn handle_batch(
|
||||
if had_uplink {
|
||||
*inner.last_active.lock().await = Instant::now();
|
||||
}
|
||||
udp_drains.push((i, sid));
|
||||
udp_drains.push((i, sid, inner));
|
||||
} else {
|
||||
results.push((i, eof_response(sid)));
|
||||
}
|
||||
@@ -902,9 +945,13 @@ async fn handle_batch(
|
||||
while let Some(join) = new_conn_jobs.join_next().await {
|
||||
match join {
|
||||
Ok((i, NewConn::Connect(r))) => results.push((i, r)),
|
||||
Ok((i, NewConn::ConnectData(Ok(sid)))) => tcp_drains.push((i, sid)),
|
||||
Ok((i, NewConn::ConnectData(Ok((sid, inner))))) => {
|
||||
tcp_drains.push((i, sid, inner));
|
||||
}
|
||||
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
|
||||
Ok((i, NewConn::UdpOpen(Ok(sid)))) => udp_drains.push((i, sid)),
|
||||
Ok((i, NewConn::UdpOpen(Ok((sid, inner))))) => {
|
||||
udp_drains.push((i, sid, inner));
|
||||
}
|
||||
Ok((i, NewConn::UdpOpen(Err(r)))) => results.push((i, r)),
|
||||
Err(e) => {
|
||||
tracing::error!("new-connection task panicked: {}", e);
|
||||
@@ -930,34 +977,38 @@ async fn handle_batch(
|
||||
LONGPOLL_DEADLINE
|
||||
};
|
||||
|
||||
let tcp_inners: Vec<Arc<SessionInner>> = {
|
||||
let sessions = state.sessions.lock().await;
|
||||
tcp_drains
|
||||
.iter()
|
||||
.filter_map(|(_, sid)| sessions.get(sid).map(|s| s.inner.clone()))
|
||||
.collect()
|
||||
};
|
||||
let udp_inners: Vec<Arc<UdpSessionInner>> = {
|
||||
let sessions = state.udp_sessions.lock().await;
|
||||
udp_drains
|
||||
.iter()
|
||||
.filter_map(|(_, sid)| sessions.get(sid).map(|s| s.inner.clone()))
|
||||
.collect()
|
||||
};
|
||||
// Phase 1 already gave us each session's Arc<…Inner>, so we
|
||||
// don't need to re-acquire the sessions map lock here. Cloning
|
||||
// the Arc is just a refcount bump.
|
||||
let tcp_inners: Vec<Arc<SessionInner>> =
|
||||
tcp_drains.iter().map(|(_, _, inner)| inner.clone()).collect();
|
||||
let udp_inners: Vec<Arc<UdpSessionInner>> =
|
||||
udp_drains.iter().map(|(_, _, inner)| inner.clone()).collect();
|
||||
|
||||
// Wait for either side to wake. Running both concurrently means
|
||||
// a TCP-only batch isn't slowed by a stale UDP watch list, and
|
||||
// vice versa.
|
||||
tokio::join!(
|
||||
wait_for_any_drainable(&tcp_inners, deadline),
|
||||
wait_for_any_udp_drainable(&udp_inners, deadline),
|
||||
);
|
||||
// Wake on whichever side has work first. The previous
|
||||
// `tokio::join!` was conjunctive — a TCP burst still paid the
|
||||
// UDP deadline in mixed batches because the UDP waiter had to
|
||||
// elapse too. `wait_for_*_drainable` short-circuits on an empty
|
||||
// slice, so we have to skip the empty side; otherwise its
|
||||
// instant return would fire the select arm before the other
|
||||
// side ever got a chance to wait.
|
||||
match (tcp_inners.is_empty(), udp_inners.is_empty()) {
|
||||
(true, true) => {}
|
||||
(false, true) => wait_for_any_drainable(&tcp_inners, deadline).await,
|
||||
(true, false) => wait_for_any_udp_drainable(&udp_inners, deadline).await,
|
||||
(false, false) => {
|
||||
tokio::select! {
|
||||
_ = wait_for_any_drainable(&tcp_inners, deadline) => {}
|
||||
_ = wait_for_any_udp_drainable(&udp_inners, deadline) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if had_writes_or_connects {
|
||||
// Adaptive settle: keep waiting in steps while new data
|
||||
// keeps arriving. Break when:
|
||||
// 1. No new data arrived in the last step (burst is over)
|
||||
// 2. 500ms max reached
|
||||
// 2. STRAGGLER_SETTLE_MAX reached
|
||||
let settle_end = Instant::now() + STRAGGLER_SETTLE_MAX;
|
||||
let mut prev_tcp_bytes: usize = 0;
|
||||
let mut prev_udp_pkts: usize = 0;
|
||||
@@ -997,58 +1048,59 @@ async fn handle_batch(
|
||||
}
|
||||
|
||||
// ---- TCP drain ----
|
||||
if !tcp_drains.is_empty() {
|
||||
let sessions = state.sessions.lock().await;
|
||||
for (i, sid) in &tcp_drains {
|
||||
if let Some(session) = sessions.get(sid) {
|
||||
let (data, eof) = drain_now(&session.inner).await;
|
||||
results.push((*i, tcp_drain_response(sid.clone(), data, eof)));
|
||||
} else {
|
||||
results.push((*i, eof_response(sid.clone())));
|
||||
}
|
||||
// Drain through each session's already-cloned Arc so the global
|
||||
// sessions map lock isn't held across the per-session
|
||||
// read_buf.lock().await.
|
||||
//
|
||||
// Cleanup is driven off `drain_now`'s returned `eof`, NOT the
|
||||
// raw `inner.eof` atomic. When the buffer exceeds
|
||||
// `TCP_DRAIN_MAX_BYTES`, `drain_now` deliberately returns
|
||||
// `eof = false` and leaves the tail in the buffer so the
|
||||
// client can pick it up on the next poll. The previous cleanup
|
||||
// read the atomic directly, so on a high-throughput session
|
||||
// that closed mid-burst (issue #460-style) it would remove the
|
||||
// session and abort the reader_task with the tail still
|
||||
// buffered, dropping those bytes.
|
||||
let mut tcp_eof_sids: Vec<String> = Vec::new();
|
||||
for (i, sid, inner) in &tcp_drains {
|
||||
let (data, eof) = drain_now(inner).await;
|
||||
if eof {
|
||||
tcp_eof_sids.push(sid.clone());
|
||||
}
|
||||
drop(sessions);
|
||||
|
||||
// Clean up eof TCP sessions.
|
||||
results.push((*i, tcp_drain_response(sid.clone(), data, eof)));
|
||||
}
|
||||
if !tcp_eof_sids.is_empty() {
|
||||
let mut sessions = state.sessions.lock().await;
|
||||
for (_, sid) in &tcp_drains {
|
||||
if let Some(s) = sessions.get(sid) {
|
||||
if s.inner.eof.load(Ordering::Acquire) {
|
||||
if let Some(s) = sessions.remove(sid) {
|
||||
s.reader_handle.abort();
|
||||
tracing::info!("session {} closed by remote (batch)", sid);
|
||||
}
|
||||
}
|
||||
for sid in &tcp_eof_sids {
|
||||
if let Some(s) = sessions.remove(sid) {
|
||||
s.reader_handle.abort();
|
||||
tracing::info!("session {} closed by remote (batch)", sid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- UDP drain ----
|
||||
if !udp_drains.is_empty() {
|
||||
{
|
||||
let sessions = state.udp_sessions.lock().await;
|
||||
for (i, sid) in &udp_drains {
|
||||
if let Some(session) = sessions.get(sid) {
|
||||
let (packets, eof) = drain_udp_now(&session.inner).await;
|
||||
results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
|
||||
} else {
|
||||
results.push((*i, eof_response(sid.clone())));
|
||||
}
|
||||
}
|
||||
// Same shape as TCP. `drain_udp_now` currently drains the full
|
||||
// queue with no per-batch cap, so its returned `eof` already
|
||||
// matches the atomic — driving cleanup off the drain return
|
||||
// is future-proofing: if a UDP per-batch packet cap is ever
|
||||
// added (mirroring `TCP_DRAIN_MAX_BYTES`), the same data-loss
|
||||
// trap that motivated the TCP-side fix reappears, and tracking
|
||||
// eof from the drain return rather than the atomic catches it.
|
||||
let mut udp_eof_sids: Vec<String> = Vec::new();
|
||||
for (i, sid, inner) in &udp_drains {
|
||||
let (packets, eof) = drain_udp_now(inner).await;
|
||||
if eof {
|
||||
udp_eof_sids.push(sid.clone());
|
||||
}
|
||||
|
||||
// Clean up eof UDP sessions so a future batch with the same
|
||||
// sid gets the "session not found" eof immediately rather
|
||||
// than re-checking the (already-stale) eof flag.
|
||||
results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
|
||||
}
|
||||
if !udp_eof_sids.is_empty() {
|
||||
let mut sessions = state.udp_sessions.lock().await;
|
||||
for (_, sid) in &udp_drains {
|
||||
if let Some(s) = sessions.get(sid) {
|
||||
if s.inner.eof.load(Ordering::Acquire) {
|
||||
if let Some(s) = sessions.remove(sid) {
|
||||
s.reader_handle.abort();
|
||||
tracing::info!("udp session {} closed by remote (batch)", sid);
|
||||
}
|
||||
}
|
||||
for sid in &udp_eof_sids {
|
||||
if let Some(s) = sessions.remove(sid) {
|
||||
s.reader_handle.abort();
|
||||
tracing::info!("udp session {} closed by remote (batch)", sid);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1149,11 +1201,12 @@ async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16
|
||||
}
|
||||
|
||||
/// Open a session and write the client's first bytes in one round trip.
|
||||
/// Returns the new sid plus an `Arc<SessionInner>` so unary callers
|
||||
/// (`handle_connect_data_single`) can drain the first response without a
|
||||
/// second sessions-map lookup. The batch caller drops the Arc — it takes
|
||||
/// a single lock across all drain-bound sessions in phase 2, which is
|
||||
/// cheaper than the Arc plumbing would be.
|
||||
/// Returns the new sid plus an `Arc<SessionInner>`. Both callers keep
|
||||
/// the Arc: the unary path (`handle_connect_data_single`) uses it to
|
||||
/// drain the first response without a second sessions-map lookup, and
|
||||
/// the batch path threads it into `tcp_drains` so phase-2 drain runs
|
||||
/// without holding the global sessions map lock across the per-session
|
||||
/// `read_buf.lock().await`.
|
||||
async fn handle_connect_data_phase1(
|
||||
state: &AppState,
|
||||
host: Option<String>,
|
||||
@@ -1274,19 +1327,27 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
|
||||
Some(s) if !s.is_empty() => s,
|
||||
_ => return TunnelResponse::error("missing sid"),
|
||||
};
|
||||
let sessions = state.sessions.lock().await;
|
||||
let session = match sessions.get(&sid) {
|
||||
Some(s) => s,
|
||||
// Clone the inner Arc under the global sessions map lock and release
|
||||
// the map lock before any await. The previous shape held the map
|
||||
// across last_active.lock(), writer.lock(), write_all, flush, AND
|
||||
// wait_and_drain — up to 5 s of head-of-line blocking on every other
|
||||
// single-op or batch request. Mirrors the batch-handler "data" path.
|
||||
let inner = {
|
||||
let sessions = state.sessions.lock().await;
|
||||
sessions.get(&sid).map(|s| s.inner.clone())
|
||||
};
|
||||
let inner = match inner {
|
||||
Some(i) => i,
|
||||
None => return TunnelResponse::error("unknown session"),
|
||||
};
|
||||
*session.inner.last_active.lock().await = Instant::now();
|
||||
*inner.last_active.lock().await = Instant::now();
|
||||
if let Some(ref data_b64) = data {
|
||||
if !data_b64.is_empty() {
|
||||
if let Ok(bytes) = B64.decode(data_b64) {
|
||||
if !bytes.is_empty() {
|
||||
let mut w = session.inner.writer.lock().await;
|
||||
let mut w = inner.writer.lock().await;
|
||||
if let Err(e) = w.write_all(&bytes).await {
|
||||
drop(w); drop(sessions);
|
||||
drop(w);
|
||||
state.sessions.lock().await.remove(&sid);
|
||||
return TunnelResponse::error(format!("write failed: {}", e));
|
||||
}
|
||||
@@ -1295,8 +1356,7 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
|
||||
}
|
||||
}
|
||||
}
|
||||
let (data, eof) = wait_and_drain(&session.inner, Duration::from_secs(5)).await;
|
||||
drop(sessions);
|
||||
let (data, eof) = wait_and_drain(&inner, Duration::from_secs(5)).await;
|
||||
if eof {
|
||||
if let Some(s) = state.sessions.lock().await.remove(&sid) {
|
||||
s.reader_handle.abort();
|
||||
@@ -1449,7 +1509,12 @@ async fn main() {
|
||||
before exposing this tunnel-node to the public internet."
|
||||
);
|
||||
}
|
||||
let state = AppState { sessions, udp_sessions, auth_key, diagnostic_mode };
|
||||
let state = AppState {
|
||||
sessions,
|
||||
udp_sessions,
|
||||
auth_key: Arc::from(auth_key),
|
||||
diagnostic_mode,
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/tunnel", post(handle_tunnel))
|
||||
@@ -2249,4 +2314,151 @@ mod tests {
|
||||
assert_eq!(r.len(), 1);
|
||||
assert_eq!(r[0]["eof"], serde_json::Value::Bool(true));
|
||||
}
|
||||
|
||||
/// Regression for the cleanup-correctness fix. Previously, the
|
||||
/// batch handler reaped any session whose `inner.eof` atomic was
|
||||
/// set, even when `drain_now` had withheld eof to keep tail bytes
|
||||
/// buffered (i.e. the buffer exceeded `TCP_DRAIN_MAX_BYTES`).
|
||||
/// Reaping aborted the reader_task and dropped the tail. Cleanup
|
||||
/// is now driven off the drain's returned `eof`, so an over-cap
|
||||
/// buffer + atomic eof keeps the session alive through the first
|
||||
/// poll and only reaps on the drain that actually returns eof.
|
||||
#[tokio::test]
|
||||
async fn batch_keeps_over_cap_session_until_tail_is_drained() {
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::State;
|
||||
|
||||
let state = fresh_state();
|
||||
let inner = fake_inner().await;
|
||||
// Prime an over-cap buffer + raw eof. drain_now will return
|
||||
// TCP_DRAIN_MAX_BYTES bytes with eof=false; the previous
|
||||
// cleanup would still reap because it read inner.eof directly.
|
||||
inner
|
||||
.read_buf
|
||||
.lock()
|
||||
.await
|
||||
.resize(TCP_DRAIN_MAX_BYTES + 4096, 0u8);
|
||||
inner.eof.store(true, Ordering::Release);
|
||||
|
||||
let sid = "over-cap-sid".to_string();
|
||||
state.sessions.lock().await.insert(
|
||||
sid.clone(),
|
||||
ManagedSession {
|
||||
inner: inner.clone(),
|
||||
reader_handle: tokio::spawn(async {}),
|
||||
udpgw_handle: None,
|
||||
},
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"k": "test-key",
|
||||
"ops": [{"op": "data", "sid": &sid}]
|
||||
})
|
||||
.to_string();
|
||||
let _resp = handle_batch(State(state.clone()), Bytes::from(body))
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
// First poll: session must still be in the map, tail intact.
|
||||
// The previous code reaped here and dropped the 4096 tail bytes.
|
||||
{
|
||||
let sessions = state.sessions.lock().await;
|
||||
let s = sessions.get(&sid).expect(
|
||||
"session removed despite tail bytes still buffered; \
|
||||
drain_now returned eof=false but cleanup ignored that \
|
||||
and read inner.eof directly",
|
||||
);
|
||||
let remaining = s.inner.read_buf.lock().await.len();
|
||||
assert_eq!(remaining, 4096, "tail must be preserved for next drain");
|
||||
}
|
||||
|
||||
// Second poll: drain_now sees buf.len() ≤ cap AND raw_eof,
|
||||
// so returns eof=true. Cleanup runs and the session is reaped.
|
||||
let body2 = serde_json::json!({
|
||||
"k": "test-key",
|
||||
"ops": [{"op": "data", "sid": &sid}]
|
||||
})
|
||||
.to_string();
|
||||
let _resp2 = handle_batch(State(state.clone()), Bytes::from(body2))
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert!(
|
||||
!state.sessions.lock().await.contains_key(&sid),
|
||||
"session should be reaped on the drain that returns eof=true",
|
||||
);
|
||||
}
|
||||
|
||||
/// Regression for the `tokio::join!` → `tokio::select!` mixed-drain
|
||||
/// fix. Before the change, a TCP-ready / UDP-idle pure-poll batch
|
||||
/// paid the full UDP `LONGPOLL_DEADLINE` (15 s) because the join
|
||||
/// was conjunctive — both arms had to complete. Under select! the
|
||||
/// TCP wake returns the response promptly even though UDP is
|
||||
/// quiet. The bound is loose (1 s) on purpose: real elapsed is
|
||||
/// in the millisecond range, but the prior bug would have
|
||||
/// triggered the test timeout instead of the assert.
|
||||
#[tokio::test]
|
||||
async fn batch_tcp_ready_does_not_pay_udp_longpoll_deadline() {
|
||||
use axum::body::Bytes;
|
||||
use axum::extract::State;
|
||||
|
||||
let state = fresh_state();
|
||||
|
||||
// TCP session with bytes already buffered → immediately drainable.
|
||||
let tcp_inner = fake_inner().await;
|
||||
tcp_inner
|
||||
.read_buf
|
||||
.lock()
|
||||
.await
|
||||
.extend_from_slice(b"ready");
|
||||
let tcp_sid = "tcp-sid".to_string();
|
||||
state.sessions.lock().await.insert(
|
||||
tcp_sid.clone(),
|
||||
ManagedSession {
|
||||
inner: tcp_inner,
|
||||
reader_handle: tokio::spawn(async {}),
|
||||
udpgw_handle: None,
|
||||
},
|
||||
);
|
||||
|
||||
// Idle UDP session — never wakes. Real upstream so udp_open
|
||||
// succeeds; we just never send anything to it.
|
||||
let udp_target = UdpSocket::bind(("127.0.0.1", 0)).await.unwrap();
|
||||
let udp_port = udp_target.local_addr().unwrap().port();
|
||||
let (udp_sid, _udp_inner) = handle_udp_open_phase1(
|
||||
&state,
|
||||
Some("127.0.0.1".into()),
|
||||
Some(udp_port),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("udp open");
|
||||
|
||||
// Pure-poll batch (no `d` payload) → had_writes_or_connects =
|
||||
// false → deadline = LONGPOLL_DEADLINE (15 s). Under the
|
||||
// previous tokio::join! wait, the UDP arm would have held the
|
||||
// response open for the full window even though TCP was
|
||||
// already drainable.
|
||||
let body = serde_json::json!({
|
||||
"k": "test-key",
|
||||
"ops": [
|
||||
{"op": "data", "sid": &tcp_sid},
|
||||
{"op": "udp_data", "sid": &udp_sid},
|
||||
]
|
||||
})
|
||||
.to_string();
|
||||
|
||||
let t0 = Instant::now();
|
||||
let _resp = handle_batch(State(state.clone()), Bytes::from(body))
|
||||
.await
|
||||
.into_response();
|
||||
let elapsed = t0.elapsed();
|
||||
|
||||
assert!(
|
||||
elapsed < Duration::from_secs(1),
|
||||
"TCP-ready / UDP-idle pure-poll batch must not pay \
|
||||
LONGPOLL_DEADLINE; elapsed={:?}",
|
||||
elapsed,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user