use std::sync::Arc; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::Mutex; use tokio_rustls::TlsAcceptor; use crate::config::Config; use crate::domain_fronter::DomainFronter; use crate::mitm::MitmCertManager; #[derive(Debug, thiserror::Error)] pub enum ProxyError { #[error("io: {0}")] Io(#[from] std::io::Error), } pub struct ProxyServer { host: String, port: u16, fronter: Arc, mitm: Arc>, } impl ProxyServer { pub fn new(config: &Config, mitm: Arc>) -> Result { let fronter = DomainFronter::new(config) .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, format!("{e}")))?; Ok(Self { host: config.listen_host.clone(), port: config.listen_port, fronter: Arc::new(fronter), mitm, }) } pub async fn run(self) -> Result<(), ProxyError> { let addr = format!("{}:{}", self.host, self.port); let listener = TcpListener::bind(&addr).await?; tracing::warn!( "Listening on {} — set your browser HTTP proxy to this address.", addr ); loop { let (sock, peer) = match listener.accept().await { Ok(x) => x, Err(e) => { tracing::error!("accept error: {}", e); continue; } }; let _ = sock.set_nodelay(true); let fronter = self.fronter.clone(); let mitm = self.mitm.clone(); tokio::spawn(async move { if let Err(e) = handle_client(sock, fronter, mitm).await { tracing::debug!("client {} closed: {}", peer, e); } }); } } } async fn handle_client( mut sock: TcpStream, fronter: Arc, mitm: Arc>, ) -> std::io::Result<()> { // Read the first request (head only). let (head, leftover) = match read_http_head(&mut sock).await? { Some(v) => v, None => return Ok(()), }; let (method, target, _version, _headers) = parse_request_head(&head) .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "bad request"))?; if method.eq_ignore_ascii_case("CONNECT") { do_connect(sock, &target, fronter, mitm).await } else { do_plain_http(sock, &head, &leftover, fronter).await } } /// Read an HTTP head (request line + headers) up to the first \r\n\r\n. /// Returns (head_bytes, leftover_after_head). The leftover may contain part /// of the request body already received. async fn read_http_head(sock: &mut TcpStream) -> std::io::Result, Vec)>> { let mut buf = Vec::with_capacity(4096); let mut tmp = [0u8; 4096]; loop { let n = sock.read(&mut tmp).await?; if n == 0 { return if buf.is_empty() { Ok(None) } else { Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "EOF mid-header", )) }; } buf.extend_from_slice(&tmp[..n]); if let Some(pos) = find_headers_end(&buf) { let head = buf[..pos].to_vec(); let leftover = buf[pos..].to_vec(); return Ok(Some((head, leftover))); } if buf.len() > 1024 * 1024 { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "headers too large", )); } } } fn find_headers_end(buf: &[u8]) -> Option { buf.windows(4).position(|w| w == b"\r\n\r\n").map(|p| p + 4) } fn parse_request_head(head: &[u8]) -> Option<(String, String, String, Vec<(String, String)>)> { let s = std::str::from_utf8(head).ok()?; let mut lines = s.split("\r\n"); let first = lines.next()?; let mut parts = first.splitn(3, ' '); let method = parts.next()?.to_string(); let target = parts.next()?.to_string(); let version = parts.next().unwrap_or("HTTP/1.1").to_string(); let mut headers = Vec::new(); for l in lines { if l.is_empty() { break; } if let Some((k, v)) = l.split_once(':') { headers.push((k.trim().to_string(), v.trim().to_string())); } } Some((method, target, version, headers)) } // ---------- CONNECT handling ---------- async fn do_connect( mut sock: TcpStream, target: &str, fronter: Arc, mitm: Arc>, ) -> std::io::Result<()> { let (host, port) = parse_host_port(target); tracing::info!("CONNECT -> {}:{}", host, port); sock.write_all(b"HTTP/1.1 200 Connection Established\r\n\r\n").await?; sock.flush().await?; // MITM: build a server config for this domain and accept TLS. let server_config = { let mut m = mitm.lock().await; match m.get_server_config(&host) { Ok(c) => c, Err(e) => { tracing::error!("cert gen failed for {}: {}", host, e); return Ok(()); } } }; let acceptor = TlsAcceptor::from(server_config); let mut tls = match acceptor.accept(sock).await { Ok(t) => t, Err(e) => { tracing::debug!("TLS accept failed for {}: {}", host, e); return Ok(()); } }; // Keep-alive loop: read HTTP requests from the decrypted stream. loop { match handle_mitm_request(&mut tls, &host, port, &fronter).await { Ok(true) => continue, Ok(false) => break, Err(e) => { tracing::debug!("MITM handler error for {}: {}", host, e); break; } } } Ok(()) } fn parse_host_port(target: &str) -> (String, u16) { if let Some((h, p)) = target.rsplit_once(':') { let port: u16 = p.parse().unwrap_or(443); (h.to_string(), port) } else { (target.to_string(), 443) } } async fn handle_mitm_request( stream: &mut S, host: &str, port: u16, fronter: &DomainFronter, ) -> std::io::Result where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin, { let (head, leftover) = match read_http_head_io(stream).await? { Some(v) => v, None => return Ok(false), }; let (method, path, _version, headers) = match parse_request_head(&head) { Some(v) => v, None => return Ok(false), }; // Read body if content-length is set. let body = read_body(stream, &leftover, &headers).await?; let url = if port == 443 { format!("https://{}{}", host, path) } else { format!("https://{}:{}{}", host, port, path) }; tracing::info!("MITM {} {}", method, url); let response = fronter.relay(&method, &url, &headers, &body).await; stream.write_all(&response).await?; stream.flush().await?; // Keep-alive unless the client asked to close. let connection_close = headers .iter() .any(|(k, v)| k.eq_ignore_ascii_case("connection") && v.eq_ignore_ascii_case("close")); Ok(!connection_close) } async fn read_http_head_io(stream: &mut S) -> std::io::Result, Vec)>> where S: tokio::io::AsyncRead + Unpin, { let mut buf = Vec::with_capacity(4096); let mut tmp = [0u8; 4096]; loop { let n = stream.read(&mut tmp).await?; if n == 0 { return if buf.is_empty() { Ok(None) } else { Err(std::io::Error::new( std::io::ErrorKind::UnexpectedEof, "EOF mid-header", )) }; } buf.extend_from_slice(&tmp[..n]); if let Some(pos) = find_headers_end(&buf) { let head = buf[..pos].to_vec(); let leftover = buf[pos..].to_vec(); return Ok(Some((head, leftover))); } if buf.len() > 1024 * 1024 { return Err(std::io::Error::new( std::io::ErrorKind::InvalidData, "headers too large", )); } } } async fn read_body( stream: &mut S, leftover: &[u8], headers: &[(String, String)], ) -> std::io::Result> where S: tokio::io::AsyncRead + Unpin, { let cl: Option = headers .iter() .find(|(k, _)| k.eq_ignore_ascii_case("content-length")) .and_then(|(_, v)| v.parse().ok()); let Some(cl) = cl else { return Ok(Vec::new()); }; let mut body = Vec::with_capacity(cl); body.extend_from_slice(&leftover[..leftover.len().min(cl)]); let mut tmp = [0u8; 8192]; while body.len() < cl { let n = stream.read(&mut tmp).await?; if n == 0 { break; } let need = cl - body.len(); body.extend_from_slice(&tmp[..n.min(need)]); } Ok(body) } // ---------- Plain HTTP proxy ---------- async fn do_plain_http( mut sock: TcpStream, head: &[u8], leftover: &[u8], fronter: Arc, ) -> std::io::Result<()> { let (method, target, _version, headers) = parse_request_head(head) .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "bad request"))?; let body = read_body(&mut sock, leftover, &headers).await?; // Browser sends `GET http://example.com/path HTTP/1.1` on plain proxy. let url = if target.starts_with("http://") || target.starts_with("https://") { target.clone() } else { // Fallback: stitch Host header with path. let host = headers .iter() .find(|(k, _)| k.eq_ignore_ascii_case("host")) .map(|(_, v)| v.clone()) .unwrap_or_default(); format!("http://{}{}", host, target) }; tracing::info!("HTTP {} {}", method, url); let response = fronter.relay(&method, &url, &headers, &body).await; sock.write_all(&response).await?; sock.flush().await?; Ok(()) }