fix(udp): surface upstream eof + bound sessions and payload size

This commit is contained in:
dazzling-no-more
2026-04-25 15:14:33 +04:00
parent 40c2b6c509
commit bf6fab31ab
2 changed files with 244 additions and 27 deletions
+134 -16
View File
@@ -123,6 +123,11 @@ struct UdpSessionInner {
packets: Mutex<VecDeque<Vec<u8>>>,
last_active: Mutex<Instant>,
notify: Notify,
/// Set when the upstream socket dies (recv error). Mirrors TCP's
/// `eof`: once true, subsequent batch drains return `eof: Some(true)`
/// so the proxy-side session task knows to exit instead of polling
/// a zombie session until the 120 s idle reaper kills it.
eof: AtomicBool,
/// Total datagrams dropped because the queue hit `UDP_QUEUE_LIMIT`.
/// Surfaced via tracing so operators can correlate "choppy call"
/// reports with relay backpressure.
@@ -209,6 +214,7 @@ async fn create_udp_session(host: &str, port: u16) -> std::io::Result<ManagedUdp
packets: Mutex::new(VecDeque::with_capacity(UDP_QUEUE_LIMIT)),
last_active: Mutex::new(Instant::now()),
notify: Notify::new(),
eof: AtomicBool::new(false),
queue_drops: AtomicU64::new(0),
});
@@ -254,7 +260,16 @@ async fn udp_reader_task(socket: Arc<UdpSocket>, session: Arc<UdpSessionInner>)
*session.last_active.lock().await = Instant::now();
session.notify.notify_one();
}
Err(_) => break,
Err(e) => {
// Upstream socket died (ICMP unreachable on a connected
// socket, container netns torn down, etc.). Surface eof
// so the proxy-side session task can exit on its next
// poll instead of looping until the idle reaper.
tracing::debug!("udp upstream recv error: {} — marking session eof", e);
session.eof.store(true, Ordering::Release);
session.notify.notify_one();
break;
}
}
}
}
@@ -362,14 +377,19 @@ async fn is_any_drainable(inners: &[Arc<SessionInner>]) -> bool {
}
/// Drain whatever UDP datagrams are currently queued — no waiting.
async fn drain_udp_now(session: &UdpSessionInner) -> Vec<Vec<u8>> {
/// Returns the eof flag alongside packets so the batch handler can
/// surface upstream-socket death without an extra round-trip.
async fn drain_udp_now(session: &UdpSessionInner) -> (Vec<Vec<u8>>, bool) {
let mut packets = session.packets.lock().await;
packets.drain(..).collect()
let drained: Vec<Vec<u8>> = packets.drain(..).collect();
let eof = session.eof.load(Ordering::Acquire);
(drained, eof)
}
/// UDP analogue of `wait_for_any_drainable`. Wakes when any session has
/// at least one queued packet. Same race-safety contract: watchers
/// self-filter against observable state to ignore stale permits.
/// at least one queued packet OR has been marked eof. Same race-safety
/// contract: watchers self-filter against observable state to ignore
/// stale permits.
async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: Duration) {
if inners.is_empty() {
return;
@@ -383,6 +403,9 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
watchers.push(tokio::spawn(async move {
loop {
inner.notify.notified().await;
if inner.eof.load(Ordering::Acquire) {
break;
}
if !inner.packets.lock().await.is_empty() {
break;
}
@@ -409,6 +432,9 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool {
for inner in inners {
if inner.eof.load(Ordering::Acquire) {
return true;
}
if !inner.packets.lock().await.is_empty() {
return true;
}
@@ -642,7 +668,13 @@ async fn handle_batch(
});
}
"udp_open" => {
had_writes_or_connects = true;
// An open *with* an initial datagram is real upstream
// work; an open without one (rare — current proxy
// never invokes it that way) is just resource alloc
// and shouldn't suppress long-poll on sibling polls.
if op.d.as_deref().map(|d| !d.is_empty()).unwrap_or(false) {
had_writes_or_connects = true;
}
let state = state.clone();
let host = op.host.clone();
let port = op.port;
@@ -830,13 +862,30 @@ async fn handle_batch(
// ---- UDP drain ----
if !udp_drains.is_empty() {
let sessions = state.udp_sessions.lock().await;
for (i, sid) in &udp_drains {
if let Some(session) = sessions.get(sid) {
let packets = drain_udp_now(&session.inner).await;
results.push((*i, udp_drain_response(sid.clone(), packets)));
} else {
results.push((*i, eof_response(sid.clone())));
{
let sessions = state.udp_sessions.lock().await;
for (i, sid) in &udp_drains {
if let Some(session) = sessions.get(sid) {
let (packets, eof) = drain_udp_now(&session.inner).await;
results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
} else {
results.push((*i, eof_response(sid.clone())));
}
}
}
// Clean up eof UDP sessions so a future batch with the same
// sid gets the "session not found" eof immediately rather
// than re-checking the (already-stale) eof flag.
let mut sessions = state.udp_sessions.lock().await;
for (_, sid) in &udp_drains {
if let Some(s) = sessions.get(sid) {
if s.inner.eof.load(Ordering::Acquire) {
if let Some(s) = sessions.remove(sid) {
s.reader_handle.abort();
tracing::info!("udp session {} closed by remote (batch)", sid);
}
}
}
}
}
@@ -863,7 +912,7 @@ fn tcp_drain_response(sid: String, data: Vec<u8>, eof: bool) -> TunnelResponse {
}
}
fn udp_drain_response(sid: String, packets: Vec<Vec<u8>>) -> TunnelResponse {
fn udp_drain_response(sid: String, packets: Vec<Vec<u8>>, eof: bool) -> TunnelResponse {
let pkts = if packets.is_empty() {
None
} else {
@@ -873,7 +922,7 @@ fn udp_drain_response(sid: String, packets: Vec<Vec<u8>>) -> TunnelResponse {
sid: Some(sid),
d: None,
pkts,
eof: Some(false),
eof: Some(eof),
e: None,
code: None,
}
@@ -1767,8 +1816,9 @@ mod tests {
assert!(state.udp_sessions.lock().await.contains_key(&sid));
wait_for_any_udp_drainable(std::slice::from_ref(&inner), Duration::from_secs(2)).await;
let packets = drain_udp_now(&inner).await;
let (packets, eof) = drain_udp_now(&inner).await;
assert_eq!(packets, vec![b"ECHO: ping".to_vec()]);
assert!(!eof);
}
/// When the upstream sends faster than the relay drains, the queue
@@ -1871,4 +1921,72 @@ mod tests {
let decoded = B64.decode(tcp_d).unwrap();
assert_eq!(&decoded[..], b"DELAYED");
}
/// When the upstream UDP socket dies (recv error), the reader_task
/// must mark the session eof so subsequent batches return
/// `eof: true` instead of looping the proxy on a zombie session.
#[tokio::test]
async fn udp_drain_surfaces_upstream_eof() {
let inner = Arc::new(UdpSessionInner {
socket: Arc::new(UdpSocket::bind(("127.0.0.1", 0)).await.unwrap()),
packets: Mutex::new(VecDeque::new()),
last_active: Mutex::new(Instant::now()),
notify: Notify::new(),
eof: AtomicBool::new(false),
queue_drops: AtomicU64::new(0),
});
// Healthy state: drain reports no eof.
let (pkts, eof) = drain_udp_now(&inner).await;
assert!(pkts.is_empty());
assert!(!eof);
// Simulate the failure path udp_reader_task takes on socket err.
inner.eof.store(true, Ordering::Release);
inner.notify.notify_one();
let (pkts, eof) = drain_udp_now(&inner).await;
assert!(pkts.is_empty());
assert!(eof, "drain should surface eof once the reader marks it");
// wait_for_any_udp_drainable also wakes immediately on eof.
let t0 = Instant::now();
wait_for_any_udp_drainable(std::slice::from_ref(&inner), Duration::from_secs(5)).await;
assert!(
t0.elapsed() < Duration::from_millis(100),
"eof should short-circuit the wait, took {:?}",
t0.elapsed()
);
// The `udp_drain_response` helper threads eof into `eof: Some(true)`.
let resp = udp_drain_response("zombie".into(), pkts, eof);
assert_eq!(resp.eof, Some(true));
assert!(resp.pkts.is_none());
}
/// A batch that targets a UDP session reaped by the cleanup task
/// (or removed via close) returns `eof: true` so the proxy task
/// exits its select loop instead of polling a zombie.
#[tokio::test]
async fn udp_data_for_missing_session_returns_eof() {
use axum::body::Bytes;
use axum::extract::State;
let state = fresh_state();
let body = serde_json::json!({
"k": "test-key",
"ops": [
{"op": "udp_data", "sid": "does-not-exist"},
]
})
.to_string();
let resp = handle_batch(State(state.clone()), Bytes::from(body))
.await
.into_response();
let (_parts, body) = resp.into_parts();
let body_bytes = axum::body::to_bytes(body, 64 * 1024).await.unwrap();
let parsed: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap();
let r = parsed["r"].as_array().unwrap();
assert_eq!(r.len(), 1);
assert_eq!(r[0]["eof"], serde_json::Value::Bool(true));
}
}