fix(udp): surface upstream eof + bound sessions and payload size

This commit is contained in:
dazzling-no-more
2026-04-25 15:14:33 +04:00
parent 40c2b6c509
commit bf6fab31ab
2 changed files with 244 additions and 27 deletions
+110 -11
View File
@@ -1,5 +1,5 @@
use std::collections::HashMap; use std::collections::{HashMap, VecDeque};
use std::net::{IpAddr, Ipv4Addr, SocketAddr}; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
@@ -604,6 +604,22 @@ const UDP_INITIAL_POLL_DELAY: Duration = Duration::from_millis(500);
/// so an idle UDP destination costs roughly one batch slot every 35 s. /// so an idle UDP destination costs roughly one batch slot every 35 s.
const UDP_MAX_POLL_DELAY: Duration = Duration::from_secs(30); const UDP_MAX_POLL_DELAY: Duration = Duration::from_secs(30);
/// Cap on simultaneous UDP relay sessions per SOCKS5 ASSOCIATE. STUN
/// candidate gathering and DNS fanout produce dozens of distinct
/// targets; an abusive or runaway client could produce thousands.
/// 256 is generous for legitimate use and bounds tunnel-node UDP
/// sessions a single ASSOCIATE can hold open. On overflow we evict
/// the oldest-inserted target (FIFO, a rough approximation of LRU —
/// good enough for long-tail eviction without tracking access on the
/// hot path).
const MAX_UDP_SESSIONS_PER_ASSOCIATE: usize = 256;
/// Drop UDP datagrams larger than this (pre-base64). Standard MTU is
/// 1500B, jumbo frames are ~9000B; anything above that is either a
/// pathologically fragmented IP datagram or abusive traffic. Each
/// datagram carries ~33% base64 + JSON envelope overhead and consumes
/// Apps Script per-account quota, so even a permissive ceiling here matters.
const MAX_UDP_PAYLOAD_BYTES: usize = 9 * 1024;
async fn handle_socks5_udp_associate( async fn handle_socks5_udp_associate(
mut control: TcpStream, mut control: TcpStream,
rewrite_ctx: Arc<RewriteCtx>, rewrite_ctx: Arc<RewriteCtx>,
@@ -623,15 +639,18 @@ async fn handle_socks5_udp_associate(
// Per RFC 1928 §6 the UDP relay only accepts datagrams from the // Per RFC 1928 §6 the UDP relay only accepts datagrams from the
// SOCKS5 client. We pin the source IP to the control TCP peer up // SOCKS5 client. We pin the source IP to the control TCP peer up
// front so a third party on the bind interface can't hijack the // front so a third party on the bind interface can't hijack the
// session by sending the first datagram. // session by sending the first datagram. THIS — not the bind IP
// below — is what actually keeps unauthenticated traffic out.
let client_peer_ip = control.peer_addr()?.ip(); let client_peer_ip = control.peer_addr()?.ip();
// The local TUN bridge talks to us over loopback. Binding the UDP relay // Bind the UDP relay to the same local IP the SOCKS5 client used
// there avoids exposing an unauthenticated UDP socket on LAN interfaces. // to reach the control TCP socket. `TcpStream::local_addr()` on an
let bind_ip = match control.local_addr()?.ip() { // accepted socket returns the concrete terminating address (e.g.
IpAddr::V4(ip) if ip.is_unspecified() => IpAddr::V4(Ipv4Addr::LOCALHOST), // 127.0.0.1 for a loopback client, 192.168.1.5 for a LAN client),
ip => ip, // not the listener's bind specifier — so this naturally tracks the
}; // path the client took. Source-IP filtering above is the security
// boundary; the bind choice is just about reachability.
let bind_ip = control.local_addr()?.ip();
let udp = Arc::new(UdpSocket::bind(SocketAddr::new(bind_ip, 0)).await?); let udp = Arc::new(UdpSocket::bind(SocketAddr::new(bind_ip, 0)).await?);
write_socks5_reply(&mut control, 0x00, Some(udp.local_addr()?)).await?; write_socks5_reply(&mut control, 0x00, Some(udp.local_addr()?)).await?;
tracing::info!( tracing::info!(
@@ -645,6 +664,12 @@ async fn handle_socks5_udp_associate(
let mut client_addr: Option<SocketAddr> = None; let mut client_addr: Option<SocketAddr> = None;
let sessions: Arc<Mutex<HashMap<SocksUdpTarget, UdpRelaySession>>> = let sessions: Arc<Mutex<HashMap<SocksUdpTarget, UdpRelaySession>>> =
Arc::new(Mutex::new(HashMap::new())); Arc::new(Mutex::new(HashMap::new()));
// Insertion-order log used for oldest-first eviction on overflow. We track
// it alongside (rather than inside) the HashMap so eviction can
// run under the same lock as the cap check.
let mut insertion_order: VecDeque<SocksUdpTarget> = VecDeque::new();
let mut oversized_dropped: u64 = 0;
let mut sessions_evicted: u64 = 0;
loop { loop {
tokio::select! { tokio::select! {
@@ -674,6 +699,23 @@ async fn handle_socks5_udp_associate(
let Some((target, payload)) = parse_socks5_udp_packet(&buf[..n]) else { let Some((target, payload)) = parse_socks5_udp_packet(&buf[..n]) else {
continue; continue;
}; };
// Size guard: drop oversize datagrams before they reach
// the mux. Each datagram costs ~payload * 1.33 in the
// batched JSON envelope plus tunnel-node CPU; uncapped,
// a runaway client can exhaust Apps Script quota.
if payload.len() > MAX_UDP_PAYLOAD_BYTES {
oversized_dropped += 1;
if oversized_dropped == 1 || oversized_dropped.is_multiple_of(100) {
tracing::debug!(
"udp datagram dropped: {} B > {} B (count={})",
payload.len(),
MAX_UDP_PAYLOAD_BYTES,
oversized_dropped,
);
}
continue;
}
let payload = payload.to_vec(); let payload = payload.to_vec();
// Fast path: existing session — push payload onto its // Fast path: existing session — push payload onto its
@@ -686,6 +728,35 @@ async fn handle_socks5_udp_associate(
} }
} }
// Cap reached → evict the oldest session before opening
// a new one. The evicted target's `UdpRelaySession` is
// dropped here, which closes its uplink channel; the
// task then exits its select! and tells tunnel-node to
// close. Any in-flight uplink already in the channel is
// delivered before the task exits.
{
let mut sess = sessions.lock().await;
while sess.len() >= MAX_UDP_SESSIONS_PER_ASSOCIATE {
let Some(victim) = insertion_order.pop_front() else {
break;
};
if sess.remove(&victim).is_some() {
sessions_evicted += 1;
if sessions_evicted == 1
|| sessions_evicted.is_multiple_of(50)
{
tracing::debug!(
"udp session cap {} reached; evicted {}:{} (total evicted={})",
MAX_UDP_SESSIONS_PER_ASSOCIATE,
victim.host,
victim.port,
sessions_evicted,
);
}
}
}
}
// New target: open via tunnel-node and spawn the per-session // New target: open via tunnel-node and spawn the per-session
// task. The first datagram rides the udp_open op so we // task. The first datagram rides the udp_open op so we
// save one round trip on session establishment. // save one round trip on session establishment.
@@ -734,6 +805,7 @@ async fn handle_socks5_udp_associate(
task_sessions.lock().await.remove(&task_target); task_sessions.lock().await.remove(&task_target);
}); });
insertion_order.push_back(target.clone());
sessions sessions
.lock() .lock()
.await .await
@@ -832,7 +904,20 @@ async fn send_udp_response_packets(
}; };
for packet in packets { for packet in packets {
let framed = build_socks5_udp_packet(target, &packet); let framed = build_socks5_udp_packet(target, &packet);
let _ = udp.send_to(&framed, client_addr).await; if let Err(e) = udp.send_to(&framed, client_addr).await {
// Errors here mean the local socket can't reach the SOCKS5
// client (ENETUNREACH, EHOSTDOWN, etc.). Surface at debug
// so a "my UDP traffic isn't coming back" report has
// something to grep for; volume is bounded by what we'd
// have delivered anyway.
tracing::debug!(
"udp send to client {} failed for {}:{}: {}",
client_addr,
target.host,
target.port,
e,
);
}
} }
} }
@@ -890,7 +975,12 @@ fn parse_socks5_udp_packet(buf: &[u8]) -> Option<(SocksUdpTarget, &[u8])> {
} }
let addr = buf[pos..pos + len].to_vec(); let addr = buf[pos..pos + len].to_vec();
pos += len; pos += len;
(String::from_utf8_lossy(&addr).into_owned(), addr) // Reject non-UTF-8 hostnames at the parser. Lossy decoding
// would forward U+FFFD into DNS and trigger an opaque
// NXDOMAIN — failing fast here gives us a clean parse-level
// drop that the test suite can assert on.
let host = std::str::from_utf8(&addr).ok()?.to_owned();
(host, addr)
} }
0x04 => { 0x04 => {
if buf.len() < pos + 16 + 2 { if buf.len() < pos + 16 + 2 {
@@ -2026,6 +2116,15 @@ mod tests {
assert!(parse_socks5_udp_packet(&raw).is_none()); assert!(parse_socks5_udp_packet(&raw).is_none());
} }
#[test]
fn socks5_udp_rejects_non_utf8_domain() {
    // The address is a one-byte host holding 0x80, a bare UTF-8
    // continuation byte that can never start a valid sequence. A lossy
    // decode would turn it into U+FFFD and hand that to DNS; the strict
    // parser is expected to drop the datagram outright instead.
    let datagram: [u8; 8] = [0, 0, 0, 0x03, 1, 0x80, 0, 80];
    let parsed = parse_socks5_udp_packet(&datagram);
    assert!(parsed.is_none());
}
#[test] #[test]
fn socks5_udp_rejects_truncated_inputs() { fn socks5_udp_rejects_truncated_inputs() {
// Header alone is not enough. // Header alone is not enough.
+134 -16
View File
@@ -123,6 +123,11 @@ struct UdpSessionInner {
packets: Mutex<VecDeque<Vec<u8>>>, packets: Mutex<VecDeque<Vec<u8>>>,
last_active: Mutex<Instant>, last_active: Mutex<Instant>,
notify: Notify, notify: Notify,
/// Set when the upstream socket dies (recv error). Mirrors TCP's
/// `eof`: once true, subsequent batch drains return `eof: Some(true)`
/// so the proxy-side session task knows to exit instead of polling
/// a zombie session until the 120 s idle reaper kills it.
eof: AtomicBool,
/// Total datagrams dropped because the queue hit `UDP_QUEUE_LIMIT`. /// Total datagrams dropped because the queue hit `UDP_QUEUE_LIMIT`.
/// Surfaced via tracing so operators can correlate "choppy call" /// Surfaced via tracing so operators can correlate "choppy call"
/// reports with relay backpressure. /// reports with relay backpressure.
@@ -209,6 +214,7 @@ async fn create_udp_session(host: &str, port: u16) -> std::io::Result<ManagedUdp
packets: Mutex::new(VecDeque::with_capacity(UDP_QUEUE_LIMIT)), packets: Mutex::new(VecDeque::with_capacity(UDP_QUEUE_LIMIT)),
last_active: Mutex::new(Instant::now()), last_active: Mutex::new(Instant::now()),
notify: Notify::new(), notify: Notify::new(),
eof: AtomicBool::new(false),
queue_drops: AtomicU64::new(0), queue_drops: AtomicU64::new(0),
}); });
@@ -254,7 +260,16 @@ async fn udp_reader_task(socket: Arc<UdpSocket>, session: Arc<UdpSessionInner>)
*session.last_active.lock().await = Instant::now(); *session.last_active.lock().await = Instant::now();
session.notify.notify_one(); session.notify.notify_one();
} }
Err(_) => break, Err(e) => {
// Upstream socket died (ICMP unreachable on a connected
// socket, container netns torn down, etc.). Surface eof
// so the proxy-side session task can exit on its next
// poll instead of looping until the idle reaper.
tracing::debug!("udp upstream recv error: {} — marking session eof", e);
session.eof.store(true, Ordering::Release);
session.notify.notify_one();
break;
}
} }
} }
} }
@@ -362,14 +377,19 @@ async fn is_any_drainable(inners: &[Arc<SessionInner>]) -> bool {
} }
/// Drain whatever UDP datagrams are currently queued — no waiting. /// Drain whatever UDP datagrams are currently queued — no waiting.
async fn drain_udp_now(session: &UdpSessionInner) -> Vec<Vec<u8>> { /// Returns the eof flag alongside packets so the batch handler can
/// surface upstream-socket death without an extra round-trip.
async fn drain_udp_now(session: &UdpSessionInner) -> (Vec<Vec<u8>>, bool) {
let mut packets = session.packets.lock().await; let mut packets = session.packets.lock().await;
packets.drain(..).collect() let drained: Vec<Vec<u8>> = packets.drain(..).collect();
let eof = session.eof.load(Ordering::Acquire);
(drained, eof)
} }
/// UDP analogue of `wait_for_any_drainable`. Wakes when any session has /// UDP analogue of `wait_for_any_drainable`. Wakes when any session has
/// at least one queued packet. Same race-safety contract: watchers /// at least one queued packet OR has been marked eof. Same race-safety
/// self-filter against observable state to ignore stale permits. /// contract: watchers self-filter against observable state to ignore
/// stale permits.
async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: Duration) { async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: Duration) {
if inners.is_empty() { if inners.is_empty() {
return; return;
@@ -383,6 +403,9 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
watchers.push(tokio::spawn(async move { watchers.push(tokio::spawn(async move {
loop { loop {
inner.notify.notified().await; inner.notify.notified().await;
if inner.eof.load(Ordering::Acquire) {
break;
}
if !inner.packets.lock().await.is_empty() { if !inner.packets.lock().await.is_empty() {
break; break;
} }
@@ -409,6 +432,9 @@ async fn wait_for_any_udp_drainable(inners: &[Arc<UdpSessionInner>], deadline: D
async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool { async fn is_any_udp_drainable(inners: &[Arc<UdpSessionInner>]) -> bool {
for inner in inners { for inner in inners {
if inner.eof.load(Ordering::Acquire) {
return true;
}
if !inner.packets.lock().await.is_empty() { if !inner.packets.lock().await.is_empty() {
return true; return true;
} }
@@ -642,7 +668,13 @@ async fn handle_batch(
}); });
} }
"udp_open" => { "udp_open" => {
had_writes_or_connects = true; // An open *with* an initial datagram is real upstream
// work; an open without one (rare — current proxy
// never invokes it that way) is just resource alloc
// and shouldn't suppress long-poll on sibling polls.
if op.d.as_deref().map(|d| !d.is_empty()).unwrap_or(false) {
had_writes_or_connects = true;
}
let state = state.clone(); let state = state.clone();
let host = op.host.clone(); let host = op.host.clone();
let port = op.port; let port = op.port;
@@ -830,13 +862,30 @@ async fn handle_batch(
// ---- UDP drain ---- // ---- UDP drain ----
if !udp_drains.is_empty() { if !udp_drains.is_empty() {
let sessions = state.udp_sessions.lock().await; {
for (i, sid) in &udp_drains { let sessions = state.udp_sessions.lock().await;
if let Some(session) = sessions.get(sid) { for (i, sid) in &udp_drains {
let packets = drain_udp_now(&session.inner).await; if let Some(session) = sessions.get(sid) {
results.push((*i, udp_drain_response(sid.clone(), packets))); let (packets, eof) = drain_udp_now(&session.inner).await;
} else { results.push((*i, udp_drain_response(sid.clone(), packets, eof)));
results.push((*i, eof_response(sid.clone()))); } else {
results.push((*i, eof_response(sid.clone())));
}
}
}
// Clean up eof UDP sessions so a future batch with the same
// sid gets the "session not found" eof immediately rather
// than re-checking the (already-stale) eof flag.
let mut sessions = state.udp_sessions.lock().await;
for (_, sid) in &udp_drains {
if let Some(s) = sessions.get(sid) {
if s.inner.eof.load(Ordering::Acquire) {
if let Some(s) = sessions.remove(sid) {
s.reader_handle.abort();
tracing::info!("udp session {} closed by remote (batch)", sid);
}
}
} }
} }
} }
@@ -863,7 +912,7 @@ fn tcp_drain_response(sid: String, data: Vec<u8>, eof: bool) -> TunnelResponse {
} }
} }
fn udp_drain_response(sid: String, packets: Vec<Vec<u8>>) -> TunnelResponse { fn udp_drain_response(sid: String, packets: Vec<Vec<u8>>, eof: bool) -> TunnelResponse {
let pkts = if packets.is_empty() { let pkts = if packets.is_empty() {
None None
} else { } else {
@@ -873,7 +922,7 @@ fn udp_drain_response(sid: String, packets: Vec<Vec<u8>>) -> TunnelResponse {
sid: Some(sid), sid: Some(sid),
d: None, d: None,
pkts, pkts,
eof: Some(false), eof: Some(eof),
e: None, e: None,
code: None, code: None,
} }
@@ -1767,8 +1816,9 @@ mod tests {
assert!(state.udp_sessions.lock().await.contains_key(&sid)); assert!(state.udp_sessions.lock().await.contains_key(&sid));
wait_for_any_udp_drainable(std::slice::from_ref(&inner), Duration::from_secs(2)).await; wait_for_any_udp_drainable(std::slice::from_ref(&inner), Duration::from_secs(2)).await;
let packets = drain_udp_now(&inner).await; let (packets, eof) = drain_udp_now(&inner).await;
assert_eq!(packets, vec![b"ECHO: ping".to_vec()]); assert_eq!(packets, vec![b"ECHO: ping".to_vec()]);
assert!(!eof);
} }
/// When the upstream sends faster than the relay drains, the queue /// When the upstream sends faster than the relay drains, the queue
@@ -1871,4 +1921,72 @@ mod tests {
let decoded = B64.decode(tcp_d).unwrap(); let decoded = B64.decode(tcp_d).unwrap();
assert_eq!(&decoded[..], b"DELAYED"); assert_eq!(&decoded[..], b"DELAYED");
} }
/// Once a session's `eof` flag has been raised (which is what
/// `udp_reader_task` does on a recv error), draining must report
/// `eof: true`, the drainable-wait must return promptly, and the
/// response helper must carry the flag through as `eof: Some(true)`.
#[tokio::test]
async fn udp_drain_surfaces_upstream_eof() {
    let socket = Arc::new(UdpSocket::bind(("127.0.0.1", 0)).await.unwrap());
    let session = Arc::new(UdpSessionInner {
        socket,
        packets: Mutex::new(VecDeque::new()),
        last_active: Mutex::new(Instant::now()),
        notify: Notify::new(),
        eof: AtomicBool::new(false),
        queue_drops: AtomicU64::new(0),
    });

    // Before any failure: nothing queued and the flag is clear.
    let (before_pkts, before_eof) = drain_udp_now(&session).await;
    assert!(before_pkts.is_empty());
    assert!(!before_eof);

    // Reproduce the two steps the reader task performs on a socket
    // error: raise the flag, then wake a waiter.
    session.eof.store(true, Ordering::Release);
    session.notify.notify_one();

    let (after_pkts, after_eof) = drain_udp_now(&session).await;
    assert!(after_pkts.is_empty());
    assert!(after_eof, "drain should surface eof once the reader marks it");

    // With eof already set, the wait must not run to its 5 s deadline.
    let started = Instant::now();
    wait_for_any_udp_drainable(std::slice::from_ref(&session), Duration::from_secs(5)).await;
    let waited = started.elapsed();
    assert!(
        waited < Duration::from_millis(100),
        "eof should short-circuit the wait, took {:?}",
        waited
    );

    // The response builder forwards the flag and omits the empty packet list.
    let response = udp_drain_response("zombie".into(), after_pkts, after_eof);
    assert_eq!(response.eof, Some(true));
    assert!(response.pkts.is_none());
}
/// A `udp_data` op naming a sid that is not in the session table must
/// come back as a single result with `eof: true`, so the proxy-side
/// task stops polling a session that no longer exists (for example one
/// that was reaped or closed earlier).
#[tokio::test]
async fn udp_data_for_missing_session_returns_eof() {
    use axum::body::Bytes;
    use axum::extract::State;

    let state = fresh_state();
    let request = serde_json::json!({
        "k": "test-key",
        "ops": [
            {"op": "udp_data", "sid": "does-not-exist"},
        ]
    })
    .to_string();

    // Run the batch handler directly and read back its JSON body.
    let (_head, raw_body) = handle_batch(State(state.clone()), Bytes::from(request))
        .await
        .into_response()
        .into_parts();
    let bytes = axum::body::to_bytes(raw_body, 64 * 1024).await.unwrap();
    let reply: serde_json::Value = serde_json::from_slice(&bytes).unwrap();

    let results = reply["r"].as_array().unwrap();
    assert_eq!(results.len(), 1);
    assert_eq!(results[0]["eof"], serde_json::Value::Bool(true));
}
} }