diff --git a/src/domain_fronter.py b/src/domain_fronter.py index 3464602..ab7705e 100644 --- a/src/domain_fronter.py +++ b/src/domain_fronter.py @@ -851,6 +851,187 @@ class DomainFronter: result += "\r\n" return result.encode() + full_body + async def relay_parallel_streaming(self, method: str, url: str, + headers: dict, body: bytes, + writer, min_size: int = 0, + chunk_size: int = 256 * 1024, + max_parallel: int = 16) -> bytes: + """Stream large file download to client as chunks arrive. + + Downloads file in parallel chunks and streams to client immediately, + avoiding memory buildup and timeout issues for large files. + + Args: + min_size: Minimum file size to enable chunking (0 = no minimum) + chunk_size: Size of each chunk in bytes (default 256KB) + max_parallel: Maximum parallel chunk downloads (default 16) + """ + + if method != "GET" or body: + return await self.relay(method, url, headers, body) + + # Probe: first chunk with Range header + range_headers = dict(headers) if headers else {} + range_headers["Range"] = f"bytes=0-{chunk_size - 1}" + first_resp = await self.relay("GET", url, range_headers, b"") + + status, resp_hdrs, resp_body = self._split_raw_response(first_resp) + + # No range support → return single response + if status != 206: + return first_resp + + # Parse total size from Content-Range + content_range = resp_hdrs.get("content-range", "") + m = re.search(r"/(\d+)", content_range) + if not m: + return self._rewrite_206_to_200(first_resp) + + total_size = int(m.group(1)) + + # Check minimum size threshold (if configured) + if min_size > 0 and total_size < min_size: + log.debug( + "File size %d < min threshold %d, using single request", + total_size, min_size + ) + return self._rewrite_206_to_200(first_resp) + + # Small file (less than one chunk) → return immediately + if total_size <= chunk_size or len(resp_body) >= total_size: + return self._rewrite_206_to_200(first_resp) + + # Build response header + response_header = "HTTP/1.1 200 OK\r\n" + skip = {"transfer-encoding", "connection", "keep-alive", + "content-length", "content-encoding", "content-range"} + for k, v in resp_hdrs.items(): + if k.lower() not in skip: + response_header += f"{k}: {v}\r\n" + response_header += f"Content-Length: {total_size}\r\n\r\n" + + # Send header + first chunk immediately + writer.write(response_header.encode() + resp_body) + await writer.drain() + + log.info("Streaming download: %d bytes, %d chunks of %d KB", + total_size, (total_size + chunk_size - 1) // chunk_size, + chunk_size // 1024) + + # Calculate remaining ranges + ranges = [] + start = len(resp_body) + while start < total_size: + end = min(start + chunk_size - 1, total_size - 1) + ranges.append((start, end)) + start = end + 1 + + # Download and stream chunks in order + sem = asyncio.Semaphore(max_parallel) + chunk_buffer = {} # idx -> chunk_data + next_to_stream = 0 + stream_event = asyncio.Event() + all_downloaded = asyncio.Event() + + async def fetch_chunk(idx: int, s: int, e: int, max_retries: int = 5): + """Download chunk with retry and timeout.""" + async with sem: + rh = dict(headers) if headers else {} + rh["Range"] = f"bytes={s}-{e}" + expected_size = e - s + 1 + + for attempt in range(max_retries): + try: + # Low timeout for small chunks (256KB should download quickly) + raw = await asyncio.wait_for( + self.relay("GET", url, rh, b""), + timeout=15 + ) + _, _, chunk_body = self._split_raw_response(raw) + + # Verify chunk size + if len(chunk_body) != expected_size: + log.warning( + "Chunk %d size mismatch: got %d, expected %d (retry %d/%d)", + idx, len(chunk_body), expected_size, attempt + 1, max_retries + ) + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) + continue + + # Store chunk in buffer + chunk_buffer[idx] = chunk_body + stream_event.set() # Signal streamer + log.debug("Downloaded chunk %d/%d", idx + 1, len(ranges)) + return + + except asyncio.TimeoutError: + log.warning( + "Chunk %d timeout (retry %d/%d)", + idx, attempt + 1, max_retries + ) + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) + except Exception as ex: + log.warning( + "Chunk %d error: %s (retry %d/%d)", + idx, ex, attempt + 1, max_retries + ) + if attempt < max_retries - 1: + await asyncio.sleep(0.5 * (attempt + 1)) + + # All retries failed + raise Exception(f"Chunk {idx} failed after {max_retries} retries") + + async def stream_chunks(): + """Stream chunks to client in sequential order.""" + nonlocal next_to_stream + try: + while next_to_stream < len(ranges): + # Wait for next chunk to be available + while next_to_stream not in chunk_buffer: + if all_downloaded.is_set(): + # All downloads done but chunk missing + raise Exception(f"Chunk {next_to_stream} never arrived") + stream_event.clear() + await asyncio.wait_for(stream_event.wait(), timeout=30) + + # Stream the chunk + chunk_data = chunk_buffer.pop(next_to_stream) + writer.write(chunk_data) + await writer.drain() + log.debug("Streamed chunk %d/%d", next_to_stream + 1, len(ranges)) + next_to_stream += 1 + + except Exception as e: + log.error("Streaming error: %s", e) + raise + + # Start downloads and streaming concurrently + download_tasks = [ + asyncio.create_task(fetch_chunk(i, s, e)) + for i, (s, e) in enumerate(ranges) + ] + stream_task = asyncio.create_task(stream_chunks()) + + # Wait for downloads to complete + try: + await asyncio.gather(*download_tasks) + all_downloaded.set() + stream_event.set() # Wake up streamer for final check + await stream_task + except Exception as e: + # Cancel remaining tasks + for task in download_tasks: + if not task.done(): + task.cancel() + if not stream_task.done(): + stream_task.cancel() + raise + + # Return empty bytes since we already streamed everything + return b"" + @staticmethod def _rewrite_206_to_200(raw: bytes) -> bytes: """Rewrite a 206 Partial Content response to 200 OK. diff --git a/src/proxy_server.py b/src/proxy_server.py index 7e17b52..b47f7a1 100644 --- a/src/proxy_server.py +++ b/src/proxy_server.py @@ -156,6 +156,14 @@ class ProxyServer: self.socks_port = config.get("socks5_port", 1080) self.fronter = DomainFronter(config) self.mitm = None + + # Chunked download settings (configurable) + exts = config.get("chunked_download_extensions", list(LARGE_FILE_EXTS)) + self._chunked_extensions = frozenset(exts) + self._chunked_bypass_check = ".*" in exts + self._chunked_min_size = config.get("chunked_download_min_size", 5 * 1024 * 1024) # 5MB default + self._chunked_chunk_size = config.get("chunked_download_chunk_size", 256 * 1024) # 256KB default + self._chunked_max_parallel = config.get("chunked_download_max_parallel", 16) # 16 parallel default self._cache = ResponseCache(max_mb=CACHE_MAX_MB) self._direct_fail_until: dict[str, float] = {} self._servers: list[asyncio.base_events.Server] = [] @@ -1000,7 +1008,7 @@ class ProxyServer: if response is None: # Relay through Apps Script try: - response = await self._relay_smart(method, url, headers, body) + response = await self._relay_smart(method, url, headers, body, writer) except Exception as e: log.error("Relay error (%s): %s", url[:60], e) err_body = f"Relay error: {e}".encode() @@ -1025,8 +1033,10 @@ class ProxyServer: self._log_response_summary(url, response) - writer.write(response) - await writer.drain() + # Only write if response not empty (streaming already sent data) + if response: + writer.write(response) + await writer.drain() except asyncio.TimeoutError: break @@ -1096,10 +1106,10 @@ class ProxyServer: additions.append("Vary: Origin") return ("\r\n".join(lines + additions) + "\r\n\r\n").encode() + body - async def _relay_smart(self, method, url, headers, body): + async def _relay_smart(self, method, url, headers, body, writer=None): """Choose optimal relay strategy based on request type. - - GET requests for likely-large downloads use parallel-range. + - GET requests for likely-large downloads use parallel-range with streaming. - All other requests (API calls, HTML, JSON, XHR) go through the single-request relay. This avoids injecting a synthetic Range header on normal traffic, which some origins honor by returning @@ -1116,19 +1126,37 @@ class ProxyServer: ) # Only probe with Range when the URL looks like a big file. if self._is_likely_download(url, headers): + # Use streaming version if writer provided + if writer: + return await self.fronter.relay_parallel_streaming( + method, url, headers, body, writer, + min_size=self._chunked_min_size, + chunk_size=self._chunked_chunk_size, + max_parallel=self._chunked_max_parallel + ) return await self.fronter.relay_parallel( - method, url, headers, body + method, url, headers, body, + chunk_size=self._chunked_chunk_size, + max_parallel=self._chunked_max_parallel ) return await self.fronter.relay(method, url, headers, body) def _is_likely_download(self, url: str, headers: dict) -> bool: - """Heuristic: is this URL likely a large file download?""" + """Heuristic: is this URL likely a large file download? + + If ".*" is in chunked_download_extensions, bypasses extension check + and returns True for all URLs (enables chunking for any file). + """ + if self._chunked_bypass_check: + return True + path = url.split("?")[0].lower() - for ext in LARGE_FILE_EXTS: - if path.endswith(ext): + for ext in self._chunked_extensions: + if path.endswith(ext.lower()): return True return False + # ── Plain HTTP forwarding ───────────────────────────────────── async def _do_http(self, header_block: bytes, reader, writer): @@ -1181,7 +1209,7 @@ class ProxyServer: log.debug("Cache HIT (HTTP): %s", url[:60]) if response is None: - response = await self._relay_smart(method, url, headers, body) + response = await self._relay_smart(method, url, headers, body, writer) # Cache successful GET if self._cache_allowed(method, url, headers, body) and response: ttl = ResponseCache.parse_ttl(response, url) @@ -1193,5 +1221,7 @@ class ProxyServer: response = self._inject_cors_headers(response, origin) self._log_response_summary(url, response) - writer.write(response) - await writer.drain() + # Only write if response not empty (streaming already sent data) + if response: + writer.write(response) + await writer.drain()