proxy_server: support chunked request bodies (#21)

Teach the incoming HTTP request parser to handle Transfer-Encoding:
chunked instead of only Content-Length-framed bodies.

Also, when a client sends Expect: 100-continue, reply with
100 Continue before waiting for the request body.

This keeps request framing correct for POST/PUT-style clients and
adds focused tests for chunked decoding and 100-continue handling.

Co-authored-by: freeinternet865 <free@internet865.com>
This commit is contained in:
freeinternet865
2026-04-23 02:35:22 +03:30
committed by GitHub
parent 520ac46de7
commit 4cfd9d9652
+219 -12
View File
@@ -950,7 +950,6 @@ where
None => return Ok(false),
};
// Read the request body, if any (Content-Length-framed or chunked).
let body = read_body(stream, &leftover, &headers).await?;
let default_port = if scheme == "https" { 443 } else { 80 };
@@ -1006,36 +1005,178 @@ where
}
}
/// Returns the value of the first header whose name equals `name`,
/// compared ASCII-case-insensitively.
///
/// `headers` is searched in order; later entries with the same name are
/// not consulted. Returns `None` when no header name matches.
fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
    for (key, value) in headers {
        if key.eq_ignore_ascii_case(name) {
            return Some(value.as_str());
        }
    }
    None
}
/// Reports whether the request's `Expect` header lists `100-continue`.
///
/// The header value is treated as a comma-separated list; each element is
/// trimmed and compared ASCII-case-insensitively. Returns `false` when the
/// header is absent.
fn expects_100_continue(headers: &[(String, String)]) -> bool {
    match header_value(headers, "expect") {
        Some(raw) => raw
            .split(',')
            .map(str::trim)
            .any(|token| token.eq_ignore_ascii_case("100-continue")),
        None => false,
    }
}
/// Builds an `io::Error` of kind `InvalidData` carrying `msg` as its text.
fn invalid_body(msg: impl Into<String>) -> std::io::Error {
    use std::io::{Error, ErrorKind};

    let text: String = msg.into();
    Error::new(ErrorKind::InvalidData, text)
}
async fn read_body<S>(
stream: &mut S,
leftover: &[u8],
headers: &[(String, String)],
) -> std::io::Result<Vec<u8>>
where
S: tokio::io::AsyncRead + Unpin,
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
{
let cl: Option<usize> = headers
.iter()
.find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
.and_then(|(_, v)| v.parse().ok());
let transfer_encoding = header_value(headers, "transfer-encoding");
let is_chunked = transfer_encoding
.map(|v| {
v.split(',')
.any(|part| part.trim().eq_ignore_ascii_case("chunked"))
})
.unwrap_or(false);
let Some(cl) = cl else {
let content_length = match header_value(headers, "content-length") {
Some(v) => Some(
v.parse::<usize>()
.map_err(|_| invalid_body(format!("invalid Content-Length: {}", v)))?,
),
None => None,
};
if transfer_encoding.is_some() && !is_chunked {
return Err(invalid_body(format!(
"unsupported Transfer-Encoding: {}",
transfer_encoding.unwrap_or_default()
)));
}
if is_chunked && content_length.is_some() {
return Err(invalid_body(
"both Transfer-Encoding: chunked and Content-Length are present",
));
}
if expects_100_continue(headers) && (is_chunked || content_length.is_some()) {
stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
stream.flush().await?;
}
if is_chunked {
return read_chunked_request_body(stream, leftover.to_vec()).await;
}
let Some(content_length) = content_length else {
return Ok(Vec::new());
};
let mut body = Vec::with_capacity(cl);
body.extend_from_slice(&leftover[..leftover.len().min(cl)]);
let mut body = Vec::with_capacity(content_length);
body.extend_from_slice(&leftover[..leftover.len().min(content_length)]);
let mut tmp = [0u8; 8192];
while body.len() < cl {
while body.len() < content_length {
let n = stream.read(&mut tmp).await?;
if n == 0 {
break;
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"EOF mid-body",
));
}
let need = cl - body.len();
let need = content_length - body.len();
body.extend_from_slice(&tmp[..n.min(need)]);
}
Ok(body)
}
/// Decodes a `Transfer-Encoding: chunked` request body.
///
/// * `stream` – source of further body bytes.
/// * `buf` – bytes already read from the connection that belong to the
///   body (may be empty); decoding starts here before touching `stream`.
///
/// Returns the concatenated chunk payloads. Chunk extensions (anything
/// after `;` on the size line) are ignored, empty lines before a size line
/// are skipped, and trailer lines after the terminating zero-size chunk
/// are read and discarded.
///
/// # Errors
/// * `InvalidData` – a size line is not UTF-8, is not a plain hexadecimal
///   number, is too large to represent, or a chunk is not followed by CRLF.
/// * `UnexpectedEof` – the stream ended before the terminating chunk and
///   trailer section were complete.
/// * Any I/O error from `stream`.
async fn read_chunked_request_body<S>(stream: &mut S, mut buf: Vec<u8>) -> std::io::Result<Vec<u8>>
where
    S: tokio::io::AsyncRead + Unpin,
{
    let mut out = Vec::new();
    let mut tmp = [0u8; 8192];
    loop {
        let line = read_crlf_line(stream, &mut buf, &mut tmp).await?;
        if line.is_empty() {
            continue;
        }
        let line_str = std::str::from_utf8(&line)
            .map_err(|_| invalid_body("non-utf8 chunk size line"))?
            .trim();
        // Whitespace is permitted between the size and a `;ext` extension.
        let size_hex = line_str.split(';').next().unwrap_or("").trim();
        let bad_size = || invalid_body(format!("bad chunk size '{}'", line_str));
        // Hex digits only: `from_str_radix` alone would also accept "+5".
        if !size_hex.bytes().all(|b| b.is_ascii_hexdigit()) {
            return Err(bad_size());
        }
        let size = usize::from_str_radix(size_hex, 16).map_err(|_| bad_size())?;

        if size == 0 {
            // Last chunk: consume trailer lines up to the blank terminator.
            while !read_crlf_line(stream, &mut buf, &mut tmp).await?.is_empty() {}
            return Ok(out);
        }

        // Payload plus its trailing CRLF. A client-supplied size close to
        // usize::MAX must not overflow the arithmetic or the slice bounds.
        let framed = size.checked_add(2).ok_or_else(|| bad_size())?;
        fill_buffer(stream, &mut buf, &mut tmp, framed).await?;
        if &buf[size..framed] != b"\r\n" {
            return Err(invalid_body("chunk missing trailing CRLF"));
        }
        out.extend_from_slice(&buf[..size]);
        buf.drain(..framed);
    }
}
/// Extracts the next CRLF-terminated line from `buf`, reading more data
/// from `stream` (via the scratch buffer `tmp`) until one is available.
///
/// On success the line is returned without its CRLF, and the line plus the
/// CRLF are removed from the front of `buf`; any bytes after the CRLF stay
/// in `buf` for the caller.
///
/// # Errors
/// * `UnexpectedEof` – the stream ended before a CRLF was seen.
/// * `InvalidData` – more than 64 KiB accumulated without a CRLF; the
///   peer is untrusted and must not be able to grow `buf` without bound.
/// * Any I/O error from `stream`.
async fn read_crlf_line<S>(
    stream: &mut S,
    buf: &mut Vec<u8>,
    tmp: &mut [u8],
) -> std::io::Result<Vec<u8>>
where
    S: tokio::io::AsyncRead + Unpin,
{
    // Longest run of bytes tolerated without a line terminator.
    const MAX_LINE: usize = 64 * 1024;

    // Offset from which `buf` still has to be searched; everything before
    // it is already known not to start a CRLF.
    let mut scanned = 0usize;
    loop {
        if let Some(rel) = buf[scanned..].windows(2).position(|w| w == b"\r\n") {
            let idx = scanned + rel;
            let line = buf[..idx].to_vec();
            buf.drain(..idx + 2);
            return Ok(line);
        }
        // Keep the final byte in the search window: it may be the '\r' of a
        // CRLF whose '\n' arrives with the next read.
        scanned = buf.len().saturating_sub(1);
        if buf.len() > MAX_LINE {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "line too long in chunked body",
            ));
        }
        let n = stream.read(tmp).await?;
        if n == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "EOF in chunked body",
            ));
        }
        buf.extend_from_slice(&tmp[..n]);
    }
}
/// Reads from `stream` into `buf` (through the scratch buffer `tmp`) until
/// `buf` holds at least `want` bytes.
///
/// Returns immediately when `buf` is already long enough. A read may
/// deliver more than was missing; the surplus stays in `buf`.
///
/// # Errors
/// * `UnexpectedEof` – the stream ended before `want` bytes were buffered.
/// * Any I/O error from `stream`.
async fn fill_buffer<S>(
    stream: &mut S,
    buf: &mut Vec<u8>,
    tmp: &mut [u8],
    want: usize,
) -> std::io::Result<()>
where
    S: tokio::io::AsyncRead + Unpin,
{
    loop {
        if buf.len() >= want {
            return Ok(());
        }
        match stream.read(tmp).await? {
            0 => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "EOF in chunked body",
                ));
            }
            got => buf.extend_from_slice(&tmp[..got]),
        }
    }
}
// ---------- Plain HTTP proxy ----------
async fn do_plain_http(
@@ -1068,3 +1209,69 @@ async fn do_plain_http(
sock.flush().await?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};

    /// Converts borrowed `(name, value)` pairs into the owned header list
    /// that `read_body` takes.
    fn headers(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        let mut owned = Vec::with_capacity(pairs.len());
        for &(name, value) in pairs {
            owned.push((name.to_string(), value.to_string()));
        }
        owned
    }

    /// A chunked body whose first chunk is split between the leftover bytes
    /// and the stream, followed by a trailer, decodes to the joined payload.
    #[tokio::test(flavor = "current_thread")]
    async fn read_body_decodes_chunked_request() {
        const ALREADY_READ: &[u8] = b"5\r\nhe";
        const ON_THE_WIRE: &[u8] = b"llo\r\n6\r\n world\r\n0\r\nFoo: bar\r\n\r\n";

        let (mut client, mut server) = duplex(1024);
        let writer = tokio::spawn(async move {
            client.write_all(ON_THE_WIRE).await.unwrap();
        });

        let request_headers = headers(&[("Transfer-Encoding", "chunked")]);
        let body = read_body(&mut server, ALREADY_READ, &request_headers)
            .await
            .unwrap();

        writer.await.unwrap();
        assert_eq!(body, b"hello world");
    }

    /// The client withholds its body until it has seen the interim
    /// response, so this only completes if `100 Continue` is sent first.
    #[tokio::test(flavor = "current_thread")]
    async fn read_body_sends_100_continue_before_waiting_for_body() {
        let (mut client, mut server) = duplex(1024);

        let client_task = tokio::spawn(async move {
            let mut received = Vec::new();
            let mut scratch = [0u8; 64];
            // Collect bytes until the blank line that ends the interim response.
            while !received.windows(4).any(|w| w == b"\r\n\r\n") {
                let n = client.read(&mut scratch).await.unwrap();
                assert!(n > 0, "proxy closed before sending 100 Continue");
                received.extend_from_slice(&scratch[..n]);
            }
            assert_eq!(received, b"HTTP/1.1 100 Continue\r\n\r\n");
            client.write_all(b"hello").await.unwrap();
        });

        let request_headers = headers(&[("Content-Length", "5"), ("Expect", "100-continue")]);
        let body = read_body(&mut server, &[], &request_headers)
            .await
            .unwrap();

        client_task.await.unwrap();
        assert_eq!(body, b"hello");
    }
}