mirror of
https://github.com/therealaleph/MasterHttpRelayVPN-RUST.git
synced 2026-05-19 08:04:39 +03:00
fix(tunnel-node): batch drain correctness and lock contention (#695)
* fix(tunnel-node): batch drain correctness and lock contention * fix(tunnel-node): single-op lock contention and batch base64 consistency
This commit is contained in:
+328
-116
@@ -393,6 +393,27 @@ async fn drain_now(session: &SessionInner) -> (Vec<u8>, bool) {
|
|||||||
/// wait for a real notify. Without this filter, an idle long-poll
|
/// wait for a real notify. Without this filter, an idle long-poll
|
||||||
/// batch could return in <1 ms on a stale permit and degrade push
|
/// batch could return in <1 ms on a stale permit and degrade push
|
||||||
/// delivery to the client's idle re-poll cadence.
|
/// delivery to the client's idle re-poll cadence.
|
||||||
|
/// `JoinHandle` newtype that aborts the task on `Drop`. Lets the waiter
|
||||||
|
/// helpers below be cancel-safe under `tokio::select!`: a plain
|
||||||
|
/// `Vec<JoinHandle<()>>` only releases its handles via `Drop`, which
|
||||||
|
/// *detaches* tasks rather than aborting them. The previous shape
|
||||||
|
/// relied on a trailing `for w in &watchers { w.abort(); }` loop —
|
||||||
|
/// fine when the function ran to completion, but past the cancellation
|
||||||
|
/// points (`is_any_drainable().await`, the inner `select!`), so
|
||||||
|
/// cancelling the loser arm of the phase-2 `select!` left N orphan
|
||||||
|
/// watchers parked on `notify.notified()`. Each held an
|
||||||
|
/// `Arc<…Inner>` and could steal a `notify_one()` permit from a
|
||||||
|
/// future batch's watcher, making that batch wait until the next
|
||||||
|
/// notify or its deadline. Wrapping in `AbortOnDrop` makes cleanup
|
||||||
|
/// happen on every exit path, including cancellation.
|
||||||
|
struct AbortOnDrop(tokio::task::JoinHandle<()>);
|
||||||
|
|
||||||
|
impl Drop for AbortOnDrop {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
self.0.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration) {
|
async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration) {
|
||||||
if inners.is_empty() {
|
if inners.is_empty() {
|
||||||
return;
|
return;
|
||||||
@@ -400,15 +421,15 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
|
|||||||
|
|
||||||
// One watcher per session. Each loops until it observes real state
|
// One watcher per session. Each loops until it observes real state
|
||||||
// (eof set or buffer non-empty) before signaling — see the
|
// (eof set or buffer non-empty) before signaling — see the
|
||||||
// race-safety note on `wait_for_any_drainable` for why. We abort the
|
// race-safety note above. Watchers are held in a Vec of
|
||||||
// watchers on return; the only state they hold is a notify
|
// `AbortOnDrop`, so they're aborted on every exit path —
|
||||||
// subscription, so abort is clean.
|
// including cancellation by an outer `select!`.
|
||||||
let (tx, mut rx) = mpsc::channel::<()>(1);
|
let (tx, mut rx) = mpsc::channel::<()>(1);
|
||||||
let mut watchers = Vec::with_capacity(inners.len());
|
let mut _watchers: Vec<AbortOnDrop> = Vec::with_capacity(inners.len());
|
||||||
for inner in inners {
|
for inner in inners {
|
||||||
let inner = inner.clone();
|
let inner = inner.clone();
|
||||||
let tx = tx.clone();
|
let tx = tx.clone();
|
||||||
watchers.push(tokio::spawn(async move {
|
_watchers.push(AbortOnDrop(tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
inner.notify.notified().await;
|
inner.notify.notified().await;
|
||||||
if inner.eof.load(Ordering::Acquire) {
|
if inner.eof.load(Ordering::Acquire) {
|
||||||
@@ -423,7 +444,7 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
|
|||||||
// notify, don't wake the caller.
|
// notify, don't wake the caller.
|
||||||
}
|
}
|
||||||
let _ = tx.try_send(());
|
let _ = tx.try_send(());
|
||||||
}));
|
})));
|
||||||
}
|
}
|
||||||
drop(tx);
|
drop(tx);
|
||||||
|
|
||||||
@@ -441,9 +462,9 @@ async fn wait_for_any_drainable(inners: &[Arc<SessionInner>], deadline: Duration
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for w in &watchers {
|
// No explicit abort loop: `_watchers`'s `AbortOnDrop` entries fire
|
||||||
w.abort();
|
// on the function returning here AND on the future being dropped
|
||||||
}
|
// mid-await by an outer `select!`.
|
||||||
}
|
}
|
||||||
|
|
||||||
/// True iff any session is currently drainable: its read buffer has
|
/// True iff any session is currently drainable: its read buffer has
|
||||||
@@ -481,12 +502,14 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// See `AbortOnDrop` and the comment on `wait_for_any_drainable`
|
||||||
|
// for why watchers must be aborted on every exit path.
|
||||||
let (tx, mut rx) = mpsc::channel::<()>(1);
|
let (tx, mut rx) = mpsc::channel::<()>(1);
|
||||||
let mut watchers = Vec::with_capacity(inners.len());
|
let mut _watchers: Vec<AbortOnDrop> = Vec::with_capacity(inners.len());
|
||||||
for inner in inners {
|
for inner in inners {
|
||||||
let inner = inner.clone();
|
let inner = inner.clone();
|
||||||
let tx = tx.clone();
|
let tx = tx.clone();
|
||||||
watchers.push(tokio::spawn(async move {
|
_watchers.push(AbortOnDrop(tokio::spawn(async move {
|
||||||
loop {
|
loop {
|
||||||
inner.notify.notified().await;
|
inner.notify.notified().await;
|
||||||
if inner.eof.load(Ordering::Acquire) {
|
if inner.eof.load(Ordering::Acquire) {
|
||||||
@@ -499,7 +522,7 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
|
|||||||
// prior batch. Loop back, don't wake the caller.
|
// prior batch. Loop back, don't wake the caller.
|
||||||
}
|
}
|
||||||
let _ = tx.try_send(());
|
let _ = tx.try_send(());
|
||||||
}));
|
})));
|
||||||
}
|
}
|
||||||
drop(tx);
|
drop(tx);
|
||||||
|
|
||||||
@@ -510,10 +533,6 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
|
|||||||
_ = tokio::time::sleep(deadline) => {}
|
_ = tokio::time::sleep(deadline) => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for w in &watchers {
|
|
||||||
w.abort();
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool {
|
async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool {
|
||||||
@@ -565,7 +584,10 @@ async fn wait_and_drain(session: &SessionInner, max_wait: Duration) -> (Vec<u8>,
|
|||||||
struct AppState {
|
struct AppState {
|
||||||
sessions: Arc<Mutex<HashMap<String, ManagedSession>>>,
|
sessions: Arc<Mutex<HashMap<String, ManagedSession>>>,
|
||||||
udp_sessions: Arc<Mutex<HashMap<String, ManagedUdpSession>>>,
|
udp_sessions: Arc<Mutex<HashMap<String, ManagedUdpSession>>>,
|
||||||
auth_key: String,
|
/// Shared, immutable after startup. `Arc<str>` so each `state.clone()`
|
||||||
|
/// — once per phase-1 spawn in the batch handler — is a refcount bump
|
||||||
|
/// instead of a fresh String allocation.
|
||||||
|
auth_key: Arc<str>,
|
||||||
/// Active probing defense: when false (default, production), bad
|
/// Active probing defense: when false (default, production), bad
|
||||||
/// AUTH_KEY responses are a generic-looking 404 with no JSON-shaped
|
/// AUTH_KEY responses are a generic-looking 404 with no JSON-shaped
|
||||||
/// "unauthorized" body — same as a static nginx 404. Active scanners
|
/// "unauthorized" body — same as a static nginx 404. Active scanners
|
||||||
@@ -650,7 +672,7 @@ async fn handle_tunnel(
|
|||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
Json(req): Json<TunnelRequest>,
|
Json(req): Json<TunnelRequest>,
|
||||||
) -> axum::response::Response {
|
) -> axum::response::Response {
|
||||||
if req.k != state.auth_key {
|
if req.k != *state.auth_key {
|
||||||
return decoy_or_unauthorized(state.diagnostic_mode);
|
return decoy_or_unauthorized(state.diagnostic_mode);
|
||||||
}
|
}
|
||||||
let resp: TunnelResponse = match req.op.as_str() {
|
let resp: TunnelResponse = match req.op.as_str() {
|
||||||
@@ -719,7 +741,7 @@ async fn handle_batch(
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if req.k != state.auth_key {
|
if req.k != *state.auth_key {
|
||||||
if state.diagnostic_mode {
|
if state.diagnostic_mode {
|
||||||
let resp = serde_json::to_vec(&BatchResponse {
|
let resp = serde_json::to_vec(&BatchResponse {
|
||||||
r: vec![TunnelResponse::error("unauthorized")],
|
r: vec![TunnelResponse::error("unauthorized")],
|
||||||
@@ -752,8 +774,13 @@ async fn handle_batch(
|
|||||||
// still fires from server-speaks-first ports and from the preread
|
// still fires from server-speaks-first ports and from the preread
|
||||||
// timeout fallback path.
|
// timeout fallback path.
|
||||||
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
|
let mut results: Vec<(usize, TunnelResponse)> = Vec::with_capacity(req.ops.len());
|
||||||
let mut tcp_drains: Vec<(usize, String)> = Vec::new();
|
// Each drain entry carries the session's `Arc<…Inner>` alongside the
|
||||||
let mut udp_drains: Vec<(usize, String)> = Vec::new();
|
// sid. Phase 2 drains through the Arc directly so the global sessions
|
||||||
|
// map lock isn't held across the per-session read_buf / packets
|
||||||
|
// mutex acquisition — without this, every other batch (and every
|
||||||
|
// connect/close op) head-of-line-blocks behind the drain.
|
||||||
|
let mut tcp_drains: Vec<(usize, String, Arc<SessionInner>)> = Vec::new();
|
||||||
|
let mut udp_drains: Vec<(usize, String, Arc<UdpSessionInner>)> = Vec::new();
|
||||||
// True iff the batch contained any op that performed a real action
|
// True iff the batch contained any op that performed a real action
|
||||||
// upstream — a new connection or a non-empty data write. A batch of
|
// upstream — a new connection or a non-empty data write. A batch of
|
||||||
// only empty "data" / "udp_data" polls (and possibly closes) leaves
|
// only empty "data" / "udp_data" polls (and possibly closes) leaves
|
||||||
@@ -762,8 +789,8 @@ async fn handle_batch(
|
|||||||
|
|
||||||
enum NewConn {
|
enum NewConn {
|
||||||
Connect(TunnelResponse),
|
Connect(TunnelResponse),
|
||||||
ConnectData(Result<String, TunnelResponse>),
|
ConnectData(Result<(String, Arc<SessionInner>), TunnelResponse>),
|
||||||
UdpOpen(Result<String, TunnelResponse>),
|
UdpOpen(Result<(String, Arc<UdpSessionInner>), TunnelResponse>),
|
||||||
}
|
}
|
||||||
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
|
let mut new_conn_jobs: JoinSet<(usize, NewConn)> = JoinSet::new();
|
||||||
|
|
||||||
@@ -785,13 +812,11 @@ async fn handle_batch(
|
|||||||
let port = op.port;
|
let port = op.port;
|
||||||
let d = op.d.clone();
|
let d = op.d.clone();
|
||||||
new_conn_jobs.spawn(async move {
|
new_conn_jobs.spawn(async move {
|
||||||
// Drop the returned Arc<SessionInner>: phase 2 below
|
// Keep the returned Arc<SessionInner>: phase 2 drains
|
||||||
// re-looks up each sid under one sessions-map lock,
|
// through it directly, so the global sessions map
|
||||||
// which is cheap. The Arc return is a convenience for
|
// lock doesn't have to be held across the per-session
|
||||||
// the single-op path only.
|
// read_buf.lock().await.
|
||||||
let r = handle_connect_data_phase1(&state, host, port, d)
|
let r = handle_connect_data_phase1(&state, host, port, d).await;
|
||||||
.await
|
|
||||||
.map(|(sid, _inner)| sid);
|
|
||||||
(i, NewConn::ConnectData(r))
|
(i, NewConn::ConnectData(r))
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -808,9 +833,7 @@ async fn handle_batch(
|
|||||||
let port = op.port;
|
let port = op.port;
|
||||||
let d = op.d.clone();
|
let d = op.d.clone();
|
||||||
new_conn_jobs.spawn(async move {
|
new_conn_jobs.spawn(async move {
|
||||||
let r = handle_udp_open_phase1(&state, host, port, d)
|
let r = handle_udp_open_phase1(&state, host, port, d).await;
|
||||||
.await
|
|
||||||
.map(|(sid, _inner)| sid);
|
|
||||||
(i, NewConn::UdpOpen(r))
|
(i, NewConn::UdpOpen(r))
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -820,26 +843,46 @@ async fn handle_batch(
|
|||||||
_ => { results.push((i, TunnelResponse::error("missing sid"))); continue; }
|
_ => { results.push((i, TunnelResponse::error("missing sid"))); continue; }
|
||||||
};
|
};
|
||||||
|
|
||||||
// Write outbound data
|
// Clone the inner under the map lock and release it
|
||||||
|
// before any await. The previous shape held the global
|
||||||
|
// sessions map across last_active.lock(), writer.lock(),
|
||||||
|
// write_all, and flush — head-of-line-blocking every
|
||||||
|
// other batch and connect/close op for the duration of
|
||||||
|
// a single upstream write. The udp_data branch below
|
||||||
|
// already does the right thing; this matches it.
|
||||||
|
let inner = {
|
||||||
let sessions = state.sessions.lock().await;
|
let sessions = state.sessions.lock().await;
|
||||||
if let Some(session) = sessions.get(&sid) {
|
sessions.get(&sid).map(|s| s.inner.clone())
|
||||||
*session.inner.last_active.lock().await = Instant::now();
|
};
|
||||||
|
if let Some(inner) = inner {
|
||||||
|
*inner.last_active.lock().await = Instant::now();
|
||||||
if let Some(ref data_b64) = op.d {
|
if let Some(ref data_b64) = op.d {
|
||||||
if !data_b64.is_empty() {
|
if !data_b64.is_empty() {
|
||||||
had_writes_or_connects = true;
|
// Decode first; only count this op as a real
|
||||||
if let Ok(bytes) = B64.decode(data_b64) {
|
// write (and demote the batch out of long-poll)
|
||||||
|
// after a successful non-empty decode. Mirrors
|
||||||
|
// the udp_data branch and avoids silently
|
||||||
|
// dropping bytes on bad base64.
|
||||||
|
let bytes = match B64.decode(data_b64) {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
results.push((
|
||||||
|
i,
|
||||||
|
TunnelResponse::error(format!("bad base64: {}", e)),
|
||||||
|
));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
if !bytes.is_empty() {
|
if !bytes.is_empty() {
|
||||||
let mut w = session.inner.writer.lock().await;
|
had_writes_or_connects = true;
|
||||||
|
let mut w = inner.writer.lock().await;
|
||||||
let _ = w.write_all(&bytes).await;
|
let _ = w.write_all(&bytes).await;
|
||||||
let _ = w.flush().await;
|
let _ = w.flush().await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
tcp_drains.push((i, sid, inner));
|
||||||
drop(sessions);
|
|
||||||
tcp_drains.push((i, sid));
|
|
||||||
} else {
|
} else {
|
||||||
drop(sessions);
|
|
||||||
results.push((i, eof_response(sid)));
|
results.push((i, eof_response(sid)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -881,7 +924,7 @@ async fn handle_batch(
|
|||||||
if had_uplink {
|
if had_uplink {
|
||||||
*inner.last_active.lock().await = Instant::now();
|
*inner.last_active.lock().await = Instant::now();
|
||||||
}
|
}
|
||||||
udp_drains.push((i, sid));
|
udp_drains.push((i, sid, inner));
|
||||||
} else {
|
} else {
|
||||||
results.push((i, eof_response(sid)));
|
results.push((i, eof_response(sid)));
|
||||||
}
|
}
|
||||||
@@ -902,9 +945,13 @@ async fn handle_batch(
|
|||||||
while let Some(join) = new_conn_jobs.join_next().await {
|
while let Some(join) = new_conn_jobs.join_next().await {
|
||||||
match join {
|
match join {
|
||||||
Ok((i, NewConn::Connect(r))) => results.push((i, r)),
|
Ok((i, NewConn::Connect(r))) => results.push((i, r)),
|
||||||
Ok((i, NewConn::ConnectData(Ok(sid)))) => tcp_drains.push((i, sid)),
|
Ok((i, NewConn::ConnectData(Ok((sid, inner))))) => {
|
||||||
|
tcp_drains.push((i, sid, inner));
|
||||||
|
}
|
||||||
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
|
Ok((i, NewConn::ConnectData(Err(r)))) => results.push((i, r)),
|
||||||
Ok((i, NewConn::UdpOpen(Ok(sid)))) => udp_drains.push((i, sid)),
|
Ok((i, NewConn::UdpOpen(Ok((sid, inner))))) => {
|
||||||
|
udp_drains.push((i, sid, inner));
|
||||||
|
}
|
||||||
Ok((i, NewConn::UdpOpen(Err(r)))) => results.push((i, r)),
|
Ok((i, NewConn::UdpOpen(Err(r)))) => results.push((i, r)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
tracing::error!("new-connection task panicked: {}", e);
|
tracing::error!("new-connection task panicked: {}", e);
|
||||||
@@ -930,34 +977,38 @@ async fn handle_batch(
|
|||||||
LONGPOLL_DEADLINE
|
LONGPOLL_DEADLINE
|
||||||
};
|
};
|
||||||
|
|
||||||
let tcp_inners: Vec<Arc<SessionInner>> = {
|
// Phase 1 already gave us each session's Arc<…Inner>, so we
|
||||||
let sessions = state.sessions.lock().await;
|
// don't need to re-acquire the sessions map lock here. Cloning
|
||||||
tcp_drains
|
// the Arc is just a refcount bump.
|
||||||
.iter()
|
let tcp_inners: Vec<Arc<SessionInner>> =
|
||||||
.filter_map(|(_, sid)| sessions.get(sid).map(|s| s.inner.clone()))
|
tcp_drains.iter().map(|(_, _, inner)| inner.clone()).collect();
|
||||||
.collect()
|
let udp_inners: Vec<Arc<UdpSessionInner>> =
|
||||||
};
|
udp_drains.iter().map(|(_, _, inner)| inner.clone()).collect();
|
||||||
let udp_inners: Vec<Arc<UdpSessionInner>> = {
|
|
||||||
let sessions = state.udp_sessions.lock().await;
|
|
||||||
udp_drains
|
|
||||||
.iter()
|
|
||||||
.filter_map(|(_, sid)| sessions.get(sid).map(|s| s.inner.clone()))
|
|
||||||
.collect()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Wait for either side to wake. Running both concurrently means
|
// Wake on whichever side has work first. The previous
|
||||||
// a TCP-only batch isn't slowed by a stale UDP watch list, and
|
// `tokio::join!` was conjunctive — a TCP burst still paid the
|
||||||
// vice versa.
|
// UDP deadline in mixed batches because the UDP waiter had to
|
||||||
tokio::join!(
|
// elapse too. `wait_for_*_drainable` short-circuits on an empty
|
||||||
wait_for_any_drainable(&tcp_inners, deadline),
|
// slice, so we have to skip the empty side; otherwise its
|
||||||
wait_for_any_udp_drainable(&udp_inners, deadline),
|
// instant return would fire the select arm before the other
|
||||||
);
|
// side ever got a chance to wait.
|
||||||
|
match (tcp_inners.is_empty(), udp_inners.is_empty()) {
|
||||||
|
(true, true) => {}
|
||||||
|
(false, true) => wait_for_any_drainable(&tcp_inners, deadline).await,
|
||||||
|
(true, false) => wait_for_any_udp_drainable(&udp_inners, deadline).await,
|
||||||
|
(false, false) => {
|
||||||
|
tokio::select! {
|
||||||
|
_ = wait_for_any_drainable(&tcp_inners, deadline) => {}
|
||||||
|
_ = wait_for_any_udp_drainable(&udp_inners, deadline) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if had_writes_or_connects {
|
if had_writes_or_connects {
|
||||||
// Adaptive settle: keep waiting in steps while new data
|
// Adaptive settle: keep waiting in steps while new data
|
||||||
// keeps arriving. Break when:
|
// keeps arriving. Break when:
|
||||||
// 1. No new data arrived in the last step (burst is over)
|
// 1. No new data arrived in the last step (burst is over)
|
||||||
// 2. 500ms max reached
|
// 2. STRAGGLER_SETTLE_MAX reached
|
||||||
let settle_end = Instant::now() + STRAGGLER_SETTLE_MAX;
|
let settle_end = Instant::now() + STRAGGLER_SETTLE_MAX;
|
||||||
let mut prev_tcp_bytes: usize = 0;
|
let mut prev_tcp_bytes: usize = 0;
|
||||||
let mut prev_udp_pkts: usize = 0;
|
let mut prev_udp_pkts: usize = 0;
|
||||||
@@ -997,53 +1048,56 @@ async fn handle_batch(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ---- TCP drain ----
|
// ---- TCP drain ----
|
||||||
if !tcp_drains.is_empty() {
|
// Drain through each session's already-cloned Arc so the global
|
||||||
let sessions = state.sessions.lock().await;
|
// sessions map lock isn't held across the per-session
|
||||||
for (i, sid) in &tcp_drains {
|
// read_buf.lock().await.
|
||||||
if let Some(session) = sessions.get(sid) {
|
//
|
||||||
let (data, eof) = drain_now(&session.inner).await;
|
// Cleanup is driven off `drain_now`'s returned `eof`, NOT the
|
||||||
|
// raw `inner.eof` atomic. When the buffer exceeds
|
||||||
|
// `TCP_DRAIN_MAX_BYTES`, `drain_now` deliberately returns
|
||||||
|
// `eof = false` and leaves the tail in the buffer so the
|
||||||
|
// client can pick it up on the next poll. The previous cleanup
|
||||||
|
// read the atomic directly, so on a high-throughput session
|
||||||
|
// that closed mid-burst (issue #460-style) it would remove the
|
||||||
|
// session and abort the reader_task with the tail still
|
||||||
|
// buffered, dropping those bytes.
|
||||||
|
let mut tcp_eof_sids: Vec<String> = Vec::new();
|
||||||
|
for (i, sid, inner) in &tcp_drains {
|
||||||
|
let (data, eof) = drain_now(inner).await;
|
||||||
|
if eof {
|
||||||
|
tcp_eof_sids.push(sid.clone());
|
||||||
|
}
|
||||||
results.push((*i, tcp_drain_response(sid.clone(), data, eof)));
|
results.push((*i, tcp_drain_response(sid.clone(), data, eof)));
|
||||||
} else {
|
|
||||||
results.push((*i, eof_response(sid.clone())));
|
|
||||||
}
|
}
|
||||||
}
|
if !tcp_eof_sids.is_empty() {
|
||||||
drop(sessions);
|
|
||||||
|
|
||||||
// Clean up eof TCP sessions.
|
|
||||||
let mut sessions = state.sessions.lock().await;
|
let mut sessions = state.sessions.lock().await;
|
||||||
for (_, sid) in &tcp_drains {
|
for sid in &tcp_eof_sids {
|
||||||
if let Some(s) = sessions.get(sid) {
|
|
||||||
if s.inner.eof.load(Ordering::Acquire) {
|
|
||||||
if let Some(s) = sessions.remove(sid) {
|
if let Some(s) = sessions.remove(sid) {
|
||||||
s.reader_handle.abort();
|
s.reader_handle.abort();
|
||||||
tracing::info!("session {} closed by remote (batch)", sid);
|
tracing::info!("session {} closed by remote (batch)", sid);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---- UDP drain ----
|
// ---- UDP drain ----
|
||||||
if !udp_drains.is_empty() {
|
// Same shape as TCP. `drain_udp_now` currently drains the full
|
||||||
{
|
// queue with no per-batch cap, so its returned `eof` already
|
||||||
let sessions = state.udp_sessions.lock().await;
|
// matches the atomic — driving cleanup off the drain return
|
||||||
for (i, sid) in &udp_drains {
|
// is future-proofing: if a UDP per-batch packet cap is ever
|
||||||
if let Some(session) = sessions.get(sid) {
|
// added (mirroring `TCP_DRAIN_MAX_BYTES`), the same data-loss
|
||||||
let (packets, eof) = drain_udp_now(&session.inner).await;
|
// trap that motivated the TCP-side fix reappears, and tracking
|
||||||
|
// eof from the drain return rather than the atomic catches it.
|
||||||
|
let mut udp_eof_sids: Vec<String> = Vec::new();
|
||||||
|
for (i, sid, inner) in &udp_drains {
|
||||||
|
let (packets, eof) = drain_udp_now(inner).await;
|
||||||
|
if eof {
|
||||||
|
udp_eof_sids.push(sid.clone());
|
||||||
|
}
|
||||||
results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
|
results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
|
||||||
} else {
|
|
||||||
results.push((*i, eof_response(sid.clone())));
|
|
||||||
}
|
}
|
||||||
}
|
if !udp_eof_sids.is_empty() {
|
||||||
}
|
|
||||||
|
|
||||||
// Clean up eof UDP sessions so a future batch with the same
|
|
||||||
// sid gets the "session not found" eof immediately rather
|
|
||||||
// than re-checking the (already-stale) eof flag.
|
|
||||||
let mut sessions = state.udp_sessions.lock().await;
|
let mut sessions = state.udp_sessions.lock().await;
|
||||||
for (_, sid) in &udp_drains {
|
for sid in &udp_eof_sids {
|
||||||
if let Some(s) = sessions.get(sid) {
|
|
||||||
if s.inner.eof.load(Ordering::Acquire) {
|
|
||||||
if let Some(s) = sessions.remove(sid) {
|
if let Some(s) = sessions.remove(sid) {
|
||||||
s.reader_handle.abort();
|
s.reader_handle.abort();
|
||||||
tracing::info!("udp session {} closed by remote (batch)", sid);
|
tracing::info!("udp session {} closed by remote (batch)", sid);
|
||||||
@@ -1051,8 +1105,6 @@ async fn handle_batch(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sort results by original index and build response
|
// Sort results by original index and build response
|
||||||
results.sort_by_key(|(i, _)| *i);
|
results.sort_by_key(|(i, _)| *i);
|
||||||
@@ -1149,11 +1201,12 @@ async fn handle_connect(state: &AppState, host: Option<String>, port: Option<u16
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Open a session and write the client's first bytes in one round trip.
|
/// Open a session and write the client's first bytes in one round trip.
|
||||||
/// Returns the new sid plus an `Arc<SessionInner>` so unary callers
|
/// Returns the new sid plus an `Arc<SessionInner>`. Both callers keep
|
||||||
/// (`handle_connect_data_single`) can drain the first response without a
|
/// the Arc: the unary path (`handle_connect_data_single`) uses it to
|
||||||
/// second sessions-map lookup. The batch caller drops the Arc — it takes
|
/// drain the first response without a second sessions-map lookup, and
|
||||||
/// a single lock across all drain-bound sessions in phase 2, which is
|
/// the batch path threads it into `tcp_drains` so phase-2 drain runs
|
||||||
/// cheaper than the Arc plumbing would be.
|
/// without holding the global sessions map lock across the per-session
|
||||||
|
/// `read_buf.lock().await`.
|
||||||
async fn handle_connect_data_phase1(
|
async fn handle_connect_data_phase1(
|
||||||
state: &AppState,
|
state: &AppState,
|
||||||
host: Option<String>,
|
host: Option<String>,
|
||||||
@@ -1274,19 +1327,27 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
|
|||||||
Some(s) if !s.is_empty() => s,
|
Some(s) if !s.is_empty() => s,
|
||||||
_ => return TunnelResponse::error("missing sid"),
|
_ => return TunnelResponse::error("missing sid"),
|
||||||
};
|
};
|
||||||
|
// Clone the inner Arc under the global sessions map lock and release
|
||||||
|
// the map lock before any await. The previous shape held the map
|
||||||
|
// across last_active.lock(), writer.lock(), write_all, flush, AND
|
||||||
|
// wait_and_drain — up to 5 s of head-of-line blocking on every other
|
||||||
|
// single-op or batch request. Mirrors the batch-handler "data" path.
|
||||||
|
let inner = {
|
||||||
let sessions = state.sessions.lock().await;
|
let sessions = state.sessions.lock().await;
|
||||||
let session = match sessions.get(&sid) {
|
sessions.get(&sid).map(|s| s.inner.clone())
|
||||||
Some(s) => s,
|
};
|
||||||
|
let inner = match inner {
|
||||||
|
Some(i) => i,
|
||||||
None => return TunnelResponse::error("unknown session"),
|
None => return TunnelResponse::error("unknown session"),
|
||||||
};
|
};
|
||||||
*session.inner.last_active.lock().await = Instant::now();
|
*inner.last_active.lock().await = Instant::now();
|
||||||
if let Some(ref data_b64) = data {
|
if let Some(ref data_b64) = data {
|
||||||
if !data_b64.is_empty() {
|
if !data_b64.is_empty() {
|
||||||
if let Ok(bytes) = B64.decode(data_b64) {
|
if let Ok(bytes) = B64.decode(data_b64) {
|
||||||
if !bytes.is_empty() {
|
if !bytes.is_empty() {
|
||||||
let mut w = session.inner.writer.lock().await;
|
let mut w = inner.writer.lock().await;
|
||||||
if let Err(e) = w.write_all(&bytes).await {
|
if let Err(e) = w.write_all(&bytes).await {
|
||||||
drop(w); drop(sessions);
|
drop(w);
|
||||||
state.sessions.lock().await.remove(&sid);
|
state.sessions.lock().await.remove(&sid);
|
||||||
return TunnelResponse::error(format!("write failed: {}", e));
|
return TunnelResponse::error(format!("write failed: {}", e));
|
||||||
}
|
}
|
||||||
@@ -1295,8 +1356,7 @@ async fn handle_data_single(state: &AppState, sid: Option<String>, data: Option<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let (data, eof) = wait_and_drain(&session.inner, Duration::from_secs(5)).await;
|
let (data, eof) = wait_and_drain(&inner, Duration::from_secs(5)).await;
|
||||||
drop(sessions);
|
|
||||||
if eof {
|
if eof {
|
||||||
if let Some(s) = state.sessions.lock().await.remove(&sid) {
|
if let Some(s) = state.sessions.lock().await.remove(&sid) {
|
||||||
s.reader_handle.abort();
|
s.reader_handle.abort();
|
||||||
@@ -1449,7 +1509,12 @@ async fn main() {
|
|||||||
before exposing this tunnel-node to the public internet."
|
before exposing this tunnel-node to the public internet."
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
let state = AppState { sessions, udp_sessions, auth_key, diagnostic_mode };
|
let state = AppState {
|
||||||
|
sessions,
|
||||||
|
udp_sessions,
|
||||||
|
auth_key: Arc::from(auth_key),
|
||||||
|
diagnostic_mode,
|
||||||
|
};
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route("/tunnel", post(handle_tunnel))
|
.route("/tunnel", post(handle_tunnel))
|
||||||
@@ -2249,4 +2314,151 @@ mod tests {
|
|||||||
assert_eq!(r.len(), 1);
|
assert_eq!(r.len(), 1);
|
||||||
assert_eq!(r[0]["eof"], serde_json::Value::Bool(true));
|
assert_eq!(r[0]["eof"], serde_json::Value::Bool(true));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Regression for the cleanup-correctness fix. Previously, the
|
||||||
|
/// batch handler reaped any session whose `inner.eof` atomic was
|
||||||
|
/// set, even when `drain_now` had withheld eof to keep tail bytes
|
||||||
|
/// buffered (i.e. the buffer exceeded `TCP_DRAIN_MAX_BYTES`).
|
||||||
|
/// Reaping aborted the reader_task and dropped the tail. Cleanup
|
||||||
|
/// is now driven off the drain's returned `eof`, so an over-cap
|
||||||
|
/// buffer + atomic eof keeps the session alive through the first
|
||||||
|
/// poll and only reaps on the drain that actually returns eof.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn batch_keeps_over_cap_session_until_tail_is_drained() {
|
||||||
|
use axum::body::Bytes;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
let state = fresh_state();
|
||||||
|
let inner = fake_inner().await;
|
||||||
|
// Prime an over-cap buffer + raw eof. drain_now will return
|
||||||
|
// TCP_DRAIN_MAX_BYTES bytes with eof=false; the previous
|
||||||
|
// cleanup would still reap because it read inner.eof directly.
|
||||||
|
inner
|
||||||
|
.read_buf
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.resize(TCP_DRAIN_MAX_BYTES + 4096, 0u8);
|
||||||
|
inner.eof.store(true, Ordering::Release);
|
||||||
|
|
||||||
|
let sid = "over-cap-sid".to_string();
|
||||||
|
state.sessions.lock().await.insert(
|
||||||
|
sid.clone(),
|
||||||
|
ManagedSession {
|
||||||
|
inner: inner.clone(),
|
||||||
|
reader_handle: tokio::spawn(async {}),
|
||||||
|
udpgw_handle: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"k": "test-key",
|
||||||
|
"ops": [{"op": "data", "sid": &sid}]
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
let _resp = handle_batch(State(state.clone()), Bytes::from(body))
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
// First poll: session must still be in the map, tail intact.
|
||||||
|
// The previous code reaped here and dropped the 4096 tail bytes.
|
||||||
|
{
|
||||||
|
let sessions = state.sessions.lock().await;
|
||||||
|
let s = sessions.get(&sid).expect(
|
||||||
|
"session removed despite tail bytes still buffered; \
|
||||||
|
drain_now returned eof=false but cleanup ignored that \
|
||||||
|
and read inner.eof directly",
|
||||||
|
);
|
||||||
|
let remaining = s.inner.read_buf.lock().await.len();
|
||||||
|
assert_eq!(remaining, 4096, "tail must be preserved for next drain");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second poll: drain_now sees buf.len() ≤ cap AND raw_eof,
|
||||||
|
// so returns eof=true. Cleanup runs and the session is reaped.
|
||||||
|
let body2 = serde_json::json!({
|
||||||
|
"k": "test-key",
|
||||||
|
"ops": [{"op": "data", "sid": &sid}]
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
let _resp2 = handle_batch(State(state.clone()), Bytes::from(body2))
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
!state.sessions.lock().await.contains_key(&sid),
|
||||||
|
"session should be reaped on the drain that returns eof=true",
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Regression for the `tokio::join!` → `tokio::select!` mixed-drain
|
||||||
|
/// fix. Before the change, a TCP-ready / UDP-idle pure-poll batch
|
||||||
|
/// paid the full UDP `LONGPOLL_DEADLINE` (15 s) because the join
|
||||||
|
/// was conjunctive — both arms had to complete. Under select! the
|
||||||
|
/// TCP wake returns the response promptly even though UDP is
|
||||||
|
/// quiet. The bound is loose (1 s) on purpose: real elapsed is
|
||||||
|
/// in the millisecond range, but the prior bug would have
|
||||||
|
/// triggered the test timeout instead of the assert.
|
||||||
|
#[tokio::test]
|
||||||
|
async fn batch_tcp_ready_does_not_pay_udp_longpoll_deadline() {
|
||||||
|
use axum::body::Bytes;
|
||||||
|
use axum::extract::State;
|
||||||
|
|
||||||
|
let state = fresh_state();
|
||||||
|
|
||||||
|
// TCP session with bytes already buffered → immediately drainable.
|
||||||
|
let tcp_inner = fake_inner().await;
|
||||||
|
tcp_inner
|
||||||
|
.read_buf
|
||||||
|
.lock()
|
||||||
|
.await
|
||||||
|
.extend_from_slice(b"ready");
|
||||||
|
let tcp_sid = "tcp-sid".to_string();
|
||||||
|
state.sessions.lock().await.insert(
|
||||||
|
tcp_sid.clone(),
|
||||||
|
ManagedSession {
|
||||||
|
inner: tcp_inner,
|
||||||
|
reader_handle: tokio::spawn(async {}),
|
||||||
|
udpgw_handle: None,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
// Idle UDP session — never wakes. Real upstream so udp_open
|
||||||
|
// succeeds; we just never send anything to it.
|
||||||
|
let udp_target = UdpSocket::bind(("127.0.0.1", 0)).await.unwrap();
|
||||||
|
let udp_port = udp_target.local_addr().unwrap().port();
|
||||||
|
let (udp_sid, _udp_inner) = handle_udp_open_phase1(
|
||||||
|
&state,
|
||||||
|
Some("127.0.0.1".into()),
|
||||||
|
Some(udp_port),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("udp open");
|
||||||
|
|
||||||
|
// Pure-poll batch (no `d` payload) → had_writes_or_connects =
|
||||||
|
// false → deadline = LONGPOLL_DEADLINE (15 s). Under the
|
||||||
|
// previous tokio::join! wait, the UDP arm would have held the
|
||||||
|
// response open for the full window even though TCP was
|
||||||
|
// already drainable.
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"k": "test-key",
|
||||||
|
"ops": [
|
||||||
|
{"op": "data", "sid": &tcp_sid},
|
||||||
|
{"op": "udp_data", "sid": &udp_sid},
|
||||||
|
]
|
||||||
|
})
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let t0 = Instant::now();
|
||||||
|
let _resp = handle_batch(State(state.clone()), Bytes::from(body))
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
let elapsed = t0.elapsed();
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
elapsed < Duration::from_secs(1),
|
||||||
|
"TCP-ready / UDP-idle pure-poll batch must not pay \
|
||||||
|
LONGPOLL_DEADLINE; elapsed={:?}",
|
||||||
|
elapsed,
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user