mirror of
https://github.com/therealaleph/MasterHttpRelayVPN-RUST.git
synced 2026-05-17 21:24:48 +03:00
proxy_server: support chunked request bodies (#21)
Teach the incoming HTTP request parser to handle Transfer-Encoding: chunked instead of only Content-Length-framed bodies. Also reply with 100 Continue when a client sends Expect: 100-continue before waiting for the request body. This keeps request framing correct for POST/PUT-style clients and adds focused tests for chunked decoding and 100-continue handling. Co-authored-by: freeinternet865 <free@internet865.com>
This commit is contained in:
+219
-12
@@ -950,7 +950,6 @@ where
|
||||
None => return Ok(false),
|
||||
};
|
||||
|
||||
// Read body if content-length is set.
|
||||
let body = read_body(stream, &leftover, &headers).await?;
|
||||
|
||||
let default_port = if scheme == "https" { 443 } else { 80 };
|
||||
@@ -1006,36 +1005,178 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
fn header_value<'a>(headers: &'a [(String, String)], name: &str) -> Option<&'a str> {
|
||||
headers
|
||||
.iter()
|
||||
.find(|(k, _)| k.eq_ignore_ascii_case(name))
|
||||
.map(|(_, v)| v.as_str())
|
||||
}
|
||||
|
||||
fn expects_100_continue(headers: &[(String, String)]) -> bool {
|
||||
header_value(headers, "expect")
|
||||
.map(|v| {
|
||||
v.split(',')
|
||||
.any(|part| part.trim().eq_ignore_ascii_case("100-continue"))
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn invalid_body(msg: impl Into<String>) -> std::io::Error {
|
||||
std::io::Error::new(std::io::ErrorKind::InvalidData, msg.into())
|
||||
}
|
||||
|
||||
async fn read_body<S>(
|
||||
stream: &mut S,
|
||||
leftover: &[u8],
|
||||
headers: &[(String, String)],
|
||||
) -> std::io::Result<Vec<u8>>
|
||||
where
|
||||
S: tokio::io::AsyncRead + Unpin,
|
||||
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
|
||||
{
|
||||
let cl: Option<usize> = headers
|
||||
.iter()
|
||||
.find(|(k, _)| k.eq_ignore_ascii_case("content-length"))
|
||||
.and_then(|(_, v)| v.parse().ok());
|
||||
let transfer_encoding = header_value(headers, "transfer-encoding");
|
||||
let is_chunked = transfer_encoding
|
||||
.map(|v| {
|
||||
v.split(',')
|
||||
.any(|part| part.trim().eq_ignore_ascii_case("chunked"))
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
let Some(cl) = cl else {
|
||||
let content_length = match header_value(headers, "content-length") {
|
||||
Some(v) => Some(
|
||||
v.parse::<usize>()
|
||||
.map_err(|_| invalid_body(format!("invalid Content-Length: {}", v)))?,
|
||||
),
|
||||
None => None,
|
||||
};
|
||||
|
||||
if transfer_encoding.is_some() && !is_chunked {
|
||||
return Err(invalid_body(format!(
|
||||
"unsupported Transfer-Encoding: {}",
|
||||
transfer_encoding.unwrap_or_default()
|
||||
)));
|
||||
}
|
||||
|
||||
if is_chunked && content_length.is_some() {
|
||||
return Err(invalid_body(
|
||||
"both Transfer-Encoding: chunked and Content-Length are present",
|
||||
));
|
||||
}
|
||||
|
||||
if expects_100_continue(headers) && (is_chunked || content_length.is_some()) {
|
||||
stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
|
||||
stream.flush().await?;
|
||||
}
|
||||
|
||||
if is_chunked {
|
||||
return read_chunked_request_body(stream, leftover.to_vec()).await;
|
||||
}
|
||||
|
||||
let Some(content_length) = content_length else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
let mut body = Vec::with_capacity(cl);
|
||||
body.extend_from_slice(&leftover[..leftover.len().min(cl)]);
|
||||
|
||||
let mut body = Vec::with_capacity(content_length);
|
||||
body.extend_from_slice(&leftover[..leftover.len().min(content_length)]);
|
||||
let mut tmp = [0u8; 8192];
|
||||
while body.len() < cl {
|
||||
while body.len() < content_length {
|
||||
let n = stream.read(&mut tmp).await?;
|
||||
if n == 0 {
|
||||
break;
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"EOF mid-body",
|
||||
));
|
||||
}
|
||||
let need = cl - body.len();
|
||||
let need = content_length - body.len();
|
||||
body.extend_from_slice(&tmp[..n.min(need)]);
|
||||
}
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
async fn read_chunked_request_body<S>(stream: &mut S, mut buf: Vec<u8>) -> std::io::Result<Vec<u8>>
|
||||
where
|
||||
S: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
let mut out = Vec::new();
|
||||
let mut tmp = [0u8; 8192];
|
||||
|
||||
loop {
|
||||
let line = read_crlf_line(stream, &mut buf, &mut tmp).await?;
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let line_str = std::str::from_utf8(&line)
|
||||
.map_err(|_| invalid_body("non-utf8 chunk size line"))?
|
||||
.trim();
|
||||
let size_hex = line_str.split(';').next().unwrap_or("");
|
||||
let size = usize::from_str_radix(size_hex, 16)
|
||||
.map_err(|_| invalid_body(format!("bad chunk size '{}'", line_str)))?;
|
||||
|
||||
if size == 0 {
|
||||
loop {
|
||||
let trailer = read_crlf_line(stream, &mut buf, &mut tmp).await?;
|
||||
if trailer.is_empty() {
|
||||
return Ok(out);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fill_buffer(stream, &mut buf, &mut tmp, size + 2).await?;
|
||||
if &buf[size..size + 2] != b"\r\n" {
|
||||
return Err(invalid_body("chunk missing trailing CRLF"));
|
||||
}
|
||||
out.extend_from_slice(&buf[..size]);
|
||||
buf.drain(..size + 2);
|
||||
}
|
||||
}
|
||||
|
||||
async fn read_crlf_line<S>(
|
||||
stream: &mut S,
|
||||
buf: &mut Vec<u8>,
|
||||
tmp: &mut [u8],
|
||||
) -> std::io::Result<Vec<u8>>
|
||||
where
|
||||
S: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
loop {
|
||||
if let Some(idx) = buf.windows(2).position(|w| w == b"\r\n") {
|
||||
let line = buf[..idx].to_vec();
|
||||
buf.drain(..idx + 2);
|
||||
return Ok(line);
|
||||
}
|
||||
let n = stream.read(tmp).await?;
|
||||
if n == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"EOF in chunked body",
|
||||
));
|
||||
}
|
||||
buf.extend_from_slice(&tmp[..n]);
|
||||
}
|
||||
}
|
||||
|
||||
async fn fill_buffer<S>(
|
||||
stream: &mut S,
|
||||
buf: &mut Vec<u8>,
|
||||
tmp: &mut [u8],
|
||||
want: usize,
|
||||
) -> std::io::Result<()>
|
||||
where
|
||||
S: tokio::io::AsyncRead + Unpin,
|
||||
{
|
||||
while buf.len() < want {
|
||||
let n = stream.read(tmp).await?;
|
||||
if n == 0 {
|
||||
return Err(std::io::Error::new(
|
||||
std::io::ErrorKind::UnexpectedEof,
|
||||
"EOF in chunked body",
|
||||
));
|
||||
}
|
||||
buf.extend_from_slice(&tmp[..n]);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------- Plain HTTP proxy ----------
|
||||
|
||||
async fn do_plain_http(
|
||||
@@ -1068,3 +1209,69 @@ async fn do_plain_http(
|
||||
sock.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::io::{duplex, AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
fn headers(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(k, v)| ((*k).to_string(), (*v).to_string()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn read_body_decodes_chunked_request() {
|
||||
let (mut client, mut server) = duplex(1024);
|
||||
let writer = tokio::spawn(async move {
|
||||
client
|
||||
.write_all(b"llo\r\n6\r\n world\r\n0\r\nFoo: bar\r\n\r\n")
|
||||
.await
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
let body = read_body(
|
||||
&mut server,
|
||||
b"5\r\nhe",
|
||||
&headers(&[("Transfer-Encoding", "chunked")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
writer.await.unwrap();
|
||||
assert_eq!(body, b"hello world");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "current_thread")]
|
||||
async fn read_body_sends_100_continue_before_waiting_for_body() {
|
||||
let (mut client, mut server) = duplex(1024);
|
||||
let client_task = tokio::spawn(async move {
|
||||
let mut got = Vec::new();
|
||||
let mut tmp = [0u8; 64];
|
||||
loop {
|
||||
let n = client.read(&mut tmp).await.unwrap();
|
||||
assert!(n > 0, "proxy closed before sending 100 Continue");
|
||||
got.extend_from_slice(&tmp[..n]);
|
||||
if got.windows(4).any(|w| w == b"\r\n\r\n") {
|
||||
break;
|
||||
}
|
||||
}
|
||||
assert_eq!(got, b"HTTP/1.1 100 Continue\r\n\r\n");
|
||||
client.write_all(b"hello").await.unwrap();
|
||||
});
|
||||
|
||||
let body = read_body(
|
||||
&mut server,
|
||||
&[],
|
||||
&headers(&[("Content-Length", "5"), ("Expect", "100-continue")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
client_task.await.unwrap();
|
||||
assert_eq!(body, b"hello");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user