From afdd3e1036064316dec2eb62afddc88f3560697d Mon Sep 17 00:00:00 2001 From: PK3NZO Date: Thu, 23 Apr 2026 15:02:16 +0330 Subject: [PATCH] Improve relay stability and add streamed parallel downloads --- PR_DESCRIPTION.md | 140 +++++++++ README.md | 11 +- README_FA.md | 11 +- config.example.json | 43 ++- main.py | 2 +- setup.py | 19 +- src/constants.py | 1 + src/domain_fronter.py | 706 ++++++++++++++++++++++++++++++++++++++---- src/h2_transport.py | 47 ++- src/proxy_server.py | 313 ++++++++++++------- 10 files changed, 1101 insertions(+), 192 deletions(-) create mode 100644 PR_DESCRIPTION.md diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md new file mode 100644 index 0000000..668c1fc --- /dev/null +++ b/PR_DESCRIPTION.md @@ -0,0 +1,140 @@ +# Suggested PR Title + +Improve relay stability and add streamed parallel downloads for large files + +## Summary + +This PR focuses on making the Apps Script relay safer, more stable, and significantly more usable for large downloads. + +The biggest change is a new streamed parallel download path for likely large files. Instead of buffering the entire file and only handing it to the browser at the end, the proxy can now start sending data to the client immediately while fetching the remaining ranges in parallel. This gives users real browser download progress and fixes the old behavior where the browser looked stuck on loading until the full file had already finished downloading in the background. + +Alongside that, this PR hardens request handling, improves shutdown behavior, adds safer defaults, reduces quota pressure for large downloads, and makes HTTP/2 fallback behavior much more defensive when the transport becomes unstable. + +## What Changed + +### Large-file downloads + +- Added a streamed parallel download path for likely large downloads +- Started sending response headers/body to the client immediately instead of waiting for full buffering +- Added disk-backed spooling for ordered chunk delivery during streaming downloads +- Added per-range validation using `Content-Range` and expected body length checks +- Added initial range-probe retries before falling back +- Added host cooldown for flaky streaming download targets so repeatedly failing hosts do not keep producing broken partial downloads +- Moved range-download traffic off the shared H2 relay path and onto the H1 pool to avoid destabilizing normal relay traffic +- Added progress logging in a more readable terminal format: + - elapsed time + - progress bar + - bytes downloaded / total + - current download speed +- Added `.bin` to default large-download extension heuristics + +### Stability and relay behavior + +- Added configurable relay and connect timeouts +- Added configurable response body cap enforcement across buffered relay paths +- Made retry behavior safer by limiting retries for non-idempotent requests +- Improved GET request coalescing by including representation-relevant headers in the coalescing key +- Added bounded chunk/request tuning for large downloads to reduce quota pressure +- Added safer defaults for: + - `chunked_download_chunk_size` + - `chunked_download_max_parallel` + - `chunked_download_max_chunks` +- Added explicit handling for unsupported request `Transfer-Encoding` + +### HTTP/2 hardening + +- Fixed H2 reconnect/reader lifecycle issues that could cause reconnect storms or stale-reader interference +- Added temporary H2 cooldown/circuit-breaker behavior after repeated consecutive H2 failures +- Reduced noisy close-notify behavior and improved connection lifecycle handling +- Prevented large parallel downloads from destabilizing the shared H2 relay path used by normal traffic + +### Security / defaults / config hygiene + +- Removed permissive CORS behavior that effectively bypassed browser CORS protections +- Changed `lan_sharing` to be opt-in by default +- Updated setup flow so LAN sharing is explicitly prompted instead of silently inheriting insecure defaults +- Synced docs and config examples with actual runtime behavior + +### Shutdown / cleanup + +- Improved shutdown cleanup so active client tasks are tracked, cancelled, and awaited during stop +- Reduced `Task was destroyed but it is pending` noise on normal shutdown + +## Config Additions / Changes + +Added or documented the following config options: + +- `relay_timeout` +- `tls_connect_timeout` +- `tcp_connect_timeout` +- `max_response_body_bytes` +- `chunked_download_extensions` +- `chunked_download_min_size` +- `chunked_download_chunk_size` +- `chunked_download_max_parallel` +- `chunked_download_max_chunks` + +Default download tuning was also made more conservative to improve stability. + +## Bugs Fixed + +- Large downloads appearing only after full buffering instead of progressing in-browser +- Browser download UX looking stalled while the proxy was still downloading in the background +- Frequent range chunk corruption/acceptance without validating `Content-Range` +- H2 reconnect/reader race behavior causing repeated `H2 reader loop ended` / `Connection lost` cascades +- Insecure default LAN exposure +- Dangerous CORS response injection behavior +- Inconsistent response size-cap enforcement +- Pending asyncio task noise during shutdown +- Silent mishandling of unsupported `Transfer-Encoding` requests + +## Manual Testing + +### Download tests + +- `https://ash-speed.hetzner.com/100MB.bin` + - Download completed successfully + - Browser showed progressive download behavior + - Terminal progress output updated correctly +- `https://fsn1-speed.hetzner.com/100MB.bin` + - Exposed host-specific flakiness during range streaming + - Added host cooldown / fallback protection after repeated streaming failures +- `https://fsn1-speed.hetzner.com/1GB.bin` + - Download completed successfully + - Completed in under 5 minutes in real-world testing + - Browser showed real incremental progress instead of waiting for full buffering + +### Runtime / startup / shutdown checks + +- Verified startup with the current config +- Verified graceful `Ctrl+C` shutdown behavior after cleanup changes +- Ran compile validation: + - `python3 -m compileall main.py src setup.py` + +## Result / Impact + +This PR substantially improves the user experience and resilience of the proxy: + +- Large downloads now behave like real downloads in the browser +- 100MB and 1GB files can complete successfully through the relay +- The system no longer feels stuck during large transfers +- Large-download traffic is isolated from the shared H2 relay path, reducing collateral instability +- Repeated H2 failures now degrade more safely instead of spamming reconnect errors +- Defaults are safer, especially for LAN exposure and browser security behavior +- Shutdown is cleaner and less noisy + +## Notes + +- Some hosts are still inherently flakier than others for parallel range downloading through Apps Script, so the host cooldown/fallback behavior is intentional and defensive +- This PR prioritizes correctness, safer degradation, and usable large-file behavior over forcing aggressive parallelism on every host + +## Screenshot + +Use the attached screenshot in the PR to show both browser-side progress and terminal-side progress: + +- Local file: `/Users/pouriarc/Downloads/photo_2026-04-23 14.55.20.jpeg` + +Recommended caption: + +> 1GB download progressing normally in the browser while the relay reports live chunk progress and throughput in the terminal. + diff --git a/README.md b/README.md index ceff8f3..5897f60 100644 --- a/README.md +++ b/README.md @@ -261,8 +261,17 @@ This project focuses entirely on the **Apps Script** relay — a free Google acc |---------|---------|-------------| | `google_ip` | `216.239.38.120` | Google IP address to connect through | | `front_domain` | `www.google.com` | Domain shown to the firewall/filter | -| `verify_ssl` | `true` | Verify TLS certificates | +| `verify_ssl` | `true` | Verify the TLS certificate on the local fronted connection to Google/CDN | +| `relay_timeout` | `25` | Total timeout for one relayed request before it fails | +| `tls_connect_timeout` | `15` | Timeout for the proxy's TLS connection to the fronted Google/CDN endpoint | +| `tcp_connect_timeout` | `10` | Timeout for direct TCP tunnels and outbound SNI-rewrite connects | +| `max_response_body_bytes` | `209715200` | Hard cap for a single relay response body after buffering/decoding | | `script_ids` | — | Multiple Script IDs for load balancing (array) | +| `chunked_download_extensions` | see [config.example.json](config.example.json) | File extensions that should use parallel range downloading. Supports `".*"` to probe all GET downloads. | +| `chunked_download_min_size` | `5242880` | Minimum total file size (5 MB) before range-parallel download stays enabled | +| `chunked_download_chunk_size` | `524288` | Per-range chunk size used by parallel downloads | +| `chunked_download_max_parallel` | `8` | Maximum simultaneous range requests for one download | +| `chunked_download_max_chunks` | `256` | Soft upper bound for total chunk requests; chunk size is raised automatically for very large files | | `block_hosts` | `[]` | Hosts that must never be tunneled (return HTTP 403). Supports exact names (`ads.example.com`) or leading-dot suffixes (`.doubleclick.net`). | | `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM, no relay). Useful for LAN resources or sites that break under MITM. | | `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the fast direct tunnel. | diff --git a/README_FA.md b/README_FA.md index 5af49ad..692bac4 100644 --- a/README_FA.md +++ b/README_FA.md @@ -210,8 +210,17 @@ json |------|---------------|-------| | `google_ip` | `216.239.38.120` | IP مورد استفاده برای مسیر Google | | `front_domain` | `www.google.com` | دامنه‌ای که فیلتر می‌بیند | -| `verify_ssl` | `true` | بررسی اعتبار TLS | +| `verify_ssl` | `true` | بررسی اعتبار TLS فقط برای اتصال fronted محلی به Google/CDN | +| `relay_timeout` | `25` | مهلت کل برای هر درخواست relay قبل از fail شدن | +| `tls_connect_timeout` | `15` | مهلت اتصال TLS پروکسی به endpoint fronted روی Google/CDN | +| `tcp_connect_timeout` | `10` | مهلت اتصال برای tunnel مستقیم و SNI-rewrite | +| `max_response_body_bytes` | `209715200` | سقف نهایی برای اندازه body هر پاسخ relay بعد از buffer/decode | | `script_ids` | - | چند Deployment ID برای load balancing | +| `chunked_download_extensions` | مطابق [config.example.json](config.example.json) | پسوند فایل‌هایی که باید از دانلود range-parallel استفاده کنند. از `".*"` هم برای probe همه دانلودهای GET پشتیبانی می‌شود. | +| `chunked_download_min_size` | `5242880` | حداقل اندازه کل فایل (۵ مگابایت) برای فعال ماندن دانلود موازی | +| `chunked_download_chunk_size` | `524288` | اندازه هر chunk در دانلود موازی | +| `chunked_download_max_parallel` | `8` | حداکثر تعداد range request همزمان برای یک دانلود | +| `chunked_download_max_chunks` | `256` | سقف نرم برای تعداد کل chunk request ها؛ برای فایل‌های خیلی بزرگ اندازه chunk به‌صورت خودکار بیشتر می‌شود | | `block_hosts` | `[]` | هاست‌هایی که هرگز نباید tunnel شوند (پاسخ 403). نام دقیق (`ads.example.com`) یا پسوند با نقطه‌ی ابتدایی (`.doubleclick.net`). | | `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | هاست‌هایی که مستقیم می‌روند (بدون MITM و بدون رله). برای منابع داخلی شبکه یا سایت‌هایی که با MITM مشکل دارند. | | `direct_google_exclude` | مراجعه به [config.example.json](config.example.json) | اپ‌های Google که باید از مسیر MITM برای رله استفاده کنند به‌جای tunnel مستقیم. | diff --git a/config.example.json b/config.example.json index 7c709ba..639aae7 100644 --- a/config.example.json +++ b/config.example.json @@ -10,8 +10,49 @@ "socks5_port": 1080, "log_level": "INFO", "verify_ssl": true, - "lan_sharing": true, + "lan_sharing": false, + "relay_timeout": 25, + "tls_connect_timeout": 15, + "tcp_connect_timeout": 10, + "max_response_body_bytes": 209715200, "parallel_relay": 1, + "chunked_download_extensions": [ + ".bin", + ".zip", + ".tar", + ".gz", + ".bz2", + ".xz", + ".7z", + ".rar", + ".exe", + ".msi", + ".dmg", + ".deb", + ".rpm", + ".apk", + ".iso", + ".img", + ".mp4", + ".mkv", + ".avi", + ".mov", + ".webm", + ".mp3", + ".flac", + ".wav", + ".aac", + ".pdf", + ".doc", + ".docx", + ".ppt", + ".pptx", + ".wasm" + ], + "chunked_download_min_size": 5242880, + "chunked_download_chunk_size": 524288, + "chunked_download_max_parallel": 8, + "chunked_download_max_chunks": 256, "block_hosts": [], "bypass_hosts": [ "localhost", diff --git a/main.py b/main.py index adad513..1aea240 100644 --- a/main.py +++ b/main.py @@ -230,7 +230,7 @@ def main(): log.info("MITM CA is already trusted.") # ── LAN sharing configuration ──────────────────────────────────────── - lan_sharing = config.get("lan_sharing", True) + lan_sharing = config.get("lan_sharing", False) if lan_sharing: # If LAN sharing is enabled and host is still localhost, change to all interfaces if config.get("listen_host", "127.0.0.1") == "127.0.0.1": diff --git a/setup.py b/setup.py index 7b0d78c..2f3166d 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,15 @@ def load_base_config() -> dict: "socks5_port": 1080, "log_level": "INFO", "verify_ssl": True, + "lan_sharing": False, + "relay_timeout": 25, + "tls_connect_timeout": 15, + "tcp_connect_timeout": 10, + "max_response_body_bytes": 200 * 1024 * 1024, + "chunked_download_min_size": 5 * 1024 * 1024, + "chunked_download_chunk_size": 512 * 1024, + "chunked_download_max_parallel": 8, + "chunked_download_max_chunks": 256, "hosts": {}, } @@ -118,7 +127,15 @@ def configure_apps_script(cfg: dict) -> dict: def configure_network(cfg: dict) -> dict: print() print(bold("Network settings") + dim(" (press enter to accept defaults)")) - cfg["listen_host"] = prompt("Listen host", default=str(cfg.get("listen_host", "127.0.0.1"))) + cfg["lan_sharing"] = prompt_yes_no( + "Enable LAN sharing?", + default=bool(cfg.get("lan_sharing", False)), + ) + + default_host = str(cfg.get("listen_host", "127.0.0.1")) + if cfg["lan_sharing"] and default_host == "127.0.0.1": + default_host = "0.0.0.0" + cfg["listen_host"] = prompt("Listen host", default=default_host) port = prompt("HTTP proxy port", default=str(cfg.get("listen_port", 8085))) try: diff --git a/src/constants.py b/src/constants.py index c8851b9..5a1810c 100644 --- a/src/constants.py +++ b/src/constants.py @@ -165,6 +165,7 @@ STATIC_EXTS: tuple[str, ...] = ( ".mp3", ".mp4", ".webm", ".wasm", ".avif", ) LARGE_FILE_EXTS = frozenset({ + ".bin", ".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar", ".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk", ".iso", ".img", diff --git a/src/domain_fronter.py b/src/domain_fronter.py index 3464602..f2b1084 100644 --- a/src/domain_fronter.py +++ b/src/domain_fronter.py @@ -16,6 +16,7 @@ import logging import re import socket import ssl +import tempfile import time from dataclasses import dataclass from urllib.parse import urlparse @@ -27,6 +28,7 @@ from constants import ( BATCH_WINDOW_MICRO, CONN_TTL, FRONT_SNI_POOL_GOOGLE, + MAX_RESPONSE_BODY_BYTES, POOL_MAX, POOL_MIN_IDLE, RELAY_TIMEOUT, @@ -83,6 +85,18 @@ def _build_sni_pool(front_domain: str, overrides: list | None) -> list[str]: class DomainFronter: _STATIC_EXTS = STATIC_EXTS + _H2_FAILURE_COOLDOWN = 60.0 + _H2_FAILURE_THRESHOLD = 3 + _DOWNLOAD_STREAM_COOLDOWN = 300.0 + _COALESCE_VARY_HEADERS = ( + "accept", + "accept-language", + "user-agent", + "sec-fetch-dest", + "sec-fetch-mode", + "sec-fetch-site", + ) + _SAFE_RETRY_METHODS = {"GET", "HEAD", "OPTIONS"} def __init__(self, config: dict): self.connect_host = config.get("google_ip", "216.239.38.120") @@ -120,6 +134,16 @@ class DomainFronter: self.auth_key = config.get("auth_key", "") self.verify_ssl = config.get("verify_ssl", True) + self._relay_timeout = self._cfg_float( + config, "relay_timeout", RELAY_TIMEOUT, minimum=1.0, + ) + self._tls_connect_timeout = self._cfg_float( + config, "tls_connect_timeout", TLS_CONNECT_TIMEOUT, minimum=1.0, + ) + self._max_response_body_bytes = self._cfg_int( + config, "max_response_body_bytes", MAX_RESPONSE_BODY_BYTES, + minimum=1024, + ) # Connection pool — TTL-based, pre-warmed, with concurrency control self._pool: list[tuple[asyncio.StreamReader, asyncio.StreamWriter, float]] = [] @@ -146,6 +170,9 @@ class DomainFronter: # Request coalescing — dedup concurrent identical GETs self._coalesce: dict[str, list[asyncio.Future]] = {} + self._h2_failure_streak = 0 + self._h2_disabled_until = 0.0 + self._stream_download_disabled_until: dict[str, float] = {} # HTTP/2 multiplexing — one connection handles all requests self._h2 = None @@ -173,6 +200,23 @@ class DomainFronter: # ── helpers ─────────────────────────────────────────────────── + @staticmethod + def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int: + try: + value = int(config.get(key, default)) + except (TypeError, ValueError): + value = default + return max(minimum, value) + + @staticmethod + def _cfg_float(config: dict, key: str, default: float, + *, minimum: float = 0.1) -> float: + try: + value = float(config.get(key, default)) + except (TypeError, ValueError): + value = default + return max(minimum, value) + def _ssl_ctx(self) -> ssl.SSLContext: ctx = ssl.create_default_context() if not self.verify_ssl: @@ -180,6 +224,54 @@ class DomainFronter: ctx.verify_mode = ssl.CERT_NONE return ctx + def _h2_available(self) -> bool: + return ( + self._h2 is not None + and self._h2.is_connected + and time.time() >= self._h2_disabled_until + ) + + def _record_h2_success(self) -> None: + self._h2_failure_streak = 0 + + def _record_h2_failure(self, exc: Exception) -> None: + self._h2_failure_streak += 1 + if self._h2_failure_streak >= self._H2_FAILURE_THRESHOLD: + self._h2_disabled_until = time.time() + self._H2_FAILURE_COOLDOWN + log.warning( + "H2 temporarily disabled for %.0fs after %d consecutive failures (%s)", + self._H2_FAILURE_COOLDOWN, + self._h2_failure_streak, + type(exc).__name__, + ) + self._h2_failure_streak = 0 + + def _stream_download_allowed(self, url: str) -> bool: + host = self._host_key(url) + if not host: + return True + until = self._stream_download_disabled_until.get(host, 0.0) + if until > time.time(): + return False + if until: + self._stream_download_disabled_until.pop(host, None) + return True + + def _mark_stream_download_failure(self, url: str, reason: str) -> None: + host = self._host_key(url) + if not host: + return + self._stream_download_disabled_until[host] = ( + time.time() + self._DOWNLOAD_STREAM_COOLDOWN + ) + log.warning( + "Parallel streaming disabled for host %s for %.0fs after failure (%s)", + host, self._DOWNLOAD_STREAM_COOLDOWN, reason, + ) + + def stream_download_allowed(self, url: str) -> bool: + return self._stream_download_allowed(url) + async def _open(self): """Open a TLS connection to the CDN. @@ -228,7 +320,7 @@ class DomainFronter: except Exception: pass reader, writer = await asyncio.wait_for( - self._open(), timeout=TLS_CONNECT_TIMEOUT + self._open(), timeout=self._tls_connect_timeout ) # Pool was empty — trigger aggressive background refill if not self._refilling: @@ -327,6 +419,171 @@ class DomainFronter: host = parsed.hostname or url_or_host return host.lower().rstrip(".") + @classmethod + def _coalesce_key(cls, url: str, headers: dict | None) -> str: + key = [url] + if headers: + lowered = {str(k).lower(): str(v) for k, v in headers.items()} + for name in cls._COALESCE_VARY_HEADERS: + value = lowered.get(name) + if value: + key.append(f"{name}={value}") + return "\n".join(key) + + @classmethod + def _retry_attempts_for_payload(cls, payload: dict) -> int: + method = str(payload.get("m", "GET")).upper() + return 2 if method in cls._SAFE_RETRY_METHODS else 1 + + @staticmethod + def _render_streaming_headers(resp_headers: dict, total_size: int) -> bytes: + lines = ["HTTP/1.1 200 OK"] + skip = { + "transfer-encoding", + "connection", + "keep-alive", + "content-length", + "content-range", + } + for key, value in resp_headers.items(): + if key.lower() in skip: + continue + lines.append(f"{key}: {value}") + lines.append(f"Content-Length: {total_size}") + lines.append("") + lines.append("") + return "\r\n".join(lines).encode() + + @staticmethod + def _parse_content_range(value: str) -> tuple[int, int, int] | None: + match = re.match(r"^\s*bytes\s+(\d+)-(\d+)/(\d+)\s*$", value or "") + if not match: + return None + start, end, total = (int(group) for group in match.groups()) + if start < 0 or end < start or total <= end: + return None + return start, end, total + + @classmethod + def _validate_range_response(cls, status: int, resp_headers: dict, + body: bytes, start_off: int, + end_off: int, + total_size: int | None = None) -> str | None: + if status != 206: + return f"status {status}" + parsed = cls._parse_content_range(resp_headers.get("content-range", "")) + if not parsed: + return "missing/invalid Content-Range" + got_start, got_end, got_total = parsed + if got_start != start_off or got_end != end_off: + return f"Content-Range mismatch {got_start}-{got_end}" + if total_size is not None and got_total != total_size: + return f"Content-Range total mismatch {got_total}/{total_size}" + expected = end_off - start_off + 1 + if len(body) != expected: + return f"short chunk {len(body)}/{expected} B" + return None + + @staticmethod + def _spool_write(file_obj, offset: int, data: bytes) -> None: + file_obj.seek(offset) + file_obj.write(data) + file_obj.flush() + + @staticmethod + def _spool_read(file_obj, offset: int, size: int) -> bytes: + file_obj.seek(offset) + return file_obj.read(size) + + @staticmethod + def _format_bytes_human(num_bytes: int) -> str: + value = float(max(0, num_bytes)) + units = ("B", "KiB", "MiB", "GiB", "TiB") + unit = units[0] + for unit in units: + if value < 1024.0 or unit == units[-1]: + break + value /= 1024.0 + if unit == "B": + return f"{int(value)} {unit}" + return f"{value:.1f} {unit}" + + @staticmethod + def _format_elapsed_short(seconds: float) -> str: + total = max(0, int(seconds)) + minutes, secs = divmod(total, 60) + hours, minutes = divmod(minutes, 60) + if hours: + return f"{hours:02d}:{minutes:02d}:{secs:02d}" + return f"{minutes:02d}:{secs:02d}" + + @staticmethod + def _render_progress_bar(done: int, total: int, width: int = 34) -> str: + if total <= 0: + return "[" + ("-" * width) + "]" + ratio = max(0.0, min(1.0, done / total)) + filled = min(width, int(round(ratio * width))) + return "[" + ("#" * filled) + ("-" * (width - filled)) + "]" + + @classmethod + def _progress_line(cls, *, elapsed: float, done: int, total: int, + speed_bytes_per_sec: float) -> str: + return ( + f"[{cls._format_elapsed_short(elapsed)}] " + f"{cls._render_progress_bar(done, total)} " + f"{cls._format_bytes_human(done)} / {cls._format_bytes_human(total)} " + f"({cls._format_bytes_human(int(speed_bytes_per_sec))}/s)" + ) + + async def _relay_payload_h1(self, payload: dict) -> bytes: + attempts = self._retry_attempts_for_payload(payload) + async with self._semaphore: + for attempt in range(attempts): + try: + return await asyncio.wait_for( + self._relay_single(payload), timeout=self._relay_timeout, + ) + except Exception as exc: + if attempt < attempts - 1: + log.debug( + "H1 relay attempt %d failed (%s: %s), retrying", + attempt + 1, type(exc).__name__, exc, + ) + await self._flush_pool() + else: + raise + + async def _range_probe(self, url: str, headers: dict, start_off: int, + end_off: int, *, max_tries: int = 3) -> bytes: + probe_headers = dict(headers) if headers else {} + probe_headers["Range"] = f"bytes={start_off}-{end_off}" + probe_payload = self._build_payload("GET", url, probe_headers, b"") + last_raw = b"" + last_status = 0 + for attempt in range(max_tries): + try: + last_raw = await self._relay_payload_h1(probe_payload) + except Exception as exc: + if attempt == max_tries - 1: + raise + log.warning( + "Initial range probe %d-%d retry %d/%d failed: %r", + start_off, end_off, attempt + 1, max_tries, exc, + ) + await asyncio.sleep(0.3 * (attempt + 1)) + continue + + last_status, _, _ = self._split_raw_response(last_raw) + if last_status == 206 or last_status < 500: + return last_raw + if attempt < max_tries - 1: + log.warning( + "Initial range probe %d-%d retry %d/%d: status %d", + start_off, end_off, attempt + 1, max_tries, last_status, + ) + await asyncio.sleep(0.3 * (attempt + 1)) + return last_raw + # ── Per-host stats ──────────────────────────────────────────── def _record_site(self, url: str, bytes_: int, latency_ns: int, @@ -521,12 +778,17 @@ class DomainFronter: async def close(self): """Cancel background tasks and close all pooled / H2 connections.""" - for task in list(self._bg_tasks): + tasks = list(self._bg_tasks) + for task in tasks: task.cancel() - if self._bg_tasks: - self._spawn(self._prewarm_script()) - if self._keepalive_task is None or self._keepalive_task.done(): - self._keepalive_task = self._spawn + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + self._bg_tasks.clear() + + self._warm_task = None + self._maintenance_task = None + self._stats_task = None + self._keepalive_task = None await self._flush_pool() @@ -538,18 +800,25 @@ class DomainFronter: async def _h2_connect(self): """Connect the HTTP/2 transport in background.""" + if self._h2 is None: + return + if time.time() < self._h2_disabled_until: + return try: await self._h2.ensure_connected() + self._record_h2_success() log.info("H2 multiplexing active — one conn handles all requests") except Exception as e: + self._record_h2_failure(e) log.warning("H2 connect failed (%s), using H1 pool fallback", e) async def _h2_connect_and_warm(self): """Connect H2, pre-warm the Apps Script container, start keepalive.""" await self._h2_connect() - if self._h2 and self._h2.is_connected: - asyncio.create_task(self._prewarm_script()) - asyncio.create_task(self._keepalive_loop()) + if self._h2_available(): + self._spawn(self._prewarm_script()) + if self._keepalive_task is None or self._keepalive_task.done(): + self._keepalive_task = self._spawn(self._keepalive_loop()) async def _prewarm_script(self): """Pre-warm Apps Script and detect /dev fast path (no redirect).""" @@ -602,10 +871,12 @@ class DomainFronter: try: await asyncio.sleep(240) # 4 minutes — saves ~90 quota hits/day vs 180s # Google's container timeout is ~5 min idle - if not self._h2 or not self._h2.is_connected: + if not self._h2_available(): try: await self._h2.reconnect() - except Exception: + self._record_h2_success() + except Exception as exc: + self._record_h2_failure(exc) continue # H2 PING to keep connection alive @@ -681,7 +952,9 @@ class DomainFronter: has_range = True break if method == "GET" and not body and not has_range: - result = await self._coalesced_submit(url, payload) + result = await self._coalesced_submit( + self._coalesce_key(url, headers), payload, + ) return result result = await self._batch_submit(payload) @@ -693,7 +966,7 @@ class DomainFronter: latency_ns = int((time.perf_counter() - t0) * 1e9) self._record_site(url, len(result), latency_ns, errored) - async def _coalesced_submit(self, url: str, payload: dict) -> bytes: + async def _coalesced_submit(self, key: str, payload: dict) -> bytes: """Dedup concurrent requests for the same URL (no Range header). Uses `_batch_lock` to atomically check-and-append, preventing a @@ -702,14 +975,14 @@ class DomainFronter: """ loop = asyncio.get_event_loop() async with self._batch_lock: - waiters = self._coalesce.get(url) + waiters = self._coalesce.get(key) if waiters is not None: future = loop.create_future() waiters.append(future) - log.debug("Coalesced request: %s", url[:60]) + log.debug("Coalesced request: %s", key.split("\n", 1)[0][:60]) waiting = True else: - self._coalesce[url] = [] + self._coalesce[key] = [] waiting = False if waiting: @@ -719,14 +992,14 @@ class DomainFronter: result = await self._batch_submit(payload) except Exception as e: async with self._batch_lock: - waiters = self._coalesce.pop(url, []) + waiters = self._coalesce.pop(key, []) for f in waiters: if not f.done(): f.set_exception(e) raise async with self._batch_lock: - waiters = self._coalesce.pop(url, []) + waiters = self._coalesce.pop(key, []) for f in waiters: if not f.done(): f.set_result(result) @@ -734,8 +1007,10 @@ class DomainFronter: async def relay_parallel(self, method: str, url: str, headers: dict, body: bytes = b"", - chunk_size: int = 256 * 1024, - max_parallel: int = 16) -> bytes: + chunk_size: int = 512 * 1024, + max_parallel: int = 8, + max_chunks: int = 256, + min_size: int = 0) -> bytes: """Relay with parallel range acceleration for large downloads. Strategy: @@ -747,17 +1022,15 @@ class DomainFronter: Since each Apps Script call takes ~2s regardless of payload size, we use: - - 256 KB chunks (safe under Apps Script response limit) - - Up to 16 chunks in flight at once via H2 multiplexing + - 512 KB chunks (fewer relay calls, lower quota pressure) + - Up to 8 chunks in flight at once via H2 multiplexing - Aggregate throughput of ~2 MB per round-trip (~2-3s) """ if method != "GET" or body: return await self.relay(method, url, headers, body) # Probe: first chunk with Range header - range_headers = dict(headers) if headers else {} - range_headers["Range"] = f"bytes=0-{chunk_size - 1}" - first_resp = await self.relay("GET", url, range_headers, b"") + first_resp = await self._range_probe(url, headers, 0, chunk_size - 1) status, resp_hdrs, resp_body = self._split_raw_response(first_resp) @@ -768,13 +1041,40 @@ class DomainFronter: return first_resp # Parse total size from Content-Range: "bytes 0-262143/1048576" - content_range = resp_hdrs.get("content-range", "") - m = re.search(r"/(\d+)", content_range) - if not m: + parsed_range = self._parse_content_range(resp_hdrs.get("content-range", "")) + if not parsed_range: # Can't parse — downgrade to 200 so the client (which sent a # plain GET) doesn't get confused by 206 + Content-Range. return self._rewrite_206_to_200(first_resp) - total_size = int(m.group(1)) + first_start, first_end, total_size = parsed_range + first_err = self._validate_range_response( + status, resp_hdrs, resp_body, first_start, first_end, total_size, + ) + if first_start != 0 or first_err: + return self._rewrite_206_to_200(first_resp) + if total_size > self._max_response_body_bytes: + return self._error_response( + 502, + "Relay response exceeds cap " + f"({self._max_response_body_bytes} bytes). " + "Increase max_response_body_bytes if your system has enough RAM.", + ) + if min_size > 0 and total_size < min_size: + return self._rewrite_206_to_200(first_resp) + if max_chunks > 0: + required_chunk_size = max( + chunk_size, + (total_size + max_chunks - 1) // max_chunks, + ) + if required_chunk_size != chunk_size: + log.info( + "Parallel download tuning: chunk size raised from %d KB to %d KB " + "to keep request count under %d", + chunk_size // 1024, + required_chunk_size // 1024, + max_chunks, + ) + chunk_size = required_chunk_size # Small file: probe already fetched it all. MUST rewrite to 200 # because the client never sent a Range header — a stray 206 here @@ -795,22 +1095,54 @@ class DomainFronter: # Concurrency-limited parallel fetch sem = asyncio.Semaphore(max_parallel) + progress_lock = asyncio.Lock() + completed_chunks = 1 # first range probe already succeeded + completed_bytes = len(resp_body) + last_progress_log = time.perf_counter() + total_chunks = len(ranges) + 1 + total_bytes = total_size async def fetch_range(s, e, max_tries: int = 3): + nonlocal completed_chunks, completed_bytes, last_progress_log async with sem: rh_base = dict(headers) if headers else {} rh_base["Range"] = f"bytes={s}-{e}" + payload = self._build_payload("GET", url, rh_base, b"") expected = e - s + 1 last_err = None for attempt in range(max_tries): try: - raw = await self.relay("GET", url, rh_base, b"") - _, _, chunk_body = self._split_raw_response(raw) - if len(chunk_body) == expected: - return chunk_body - last_err = ( - f"short chunk {len(chunk_body)}/{expected} B" + raw = await self._relay_payload_h1(payload) + chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw) + err = self._validate_range_response( + chunk_status, chunk_headers, chunk_body, + s, e, total_size, ) + if err is None: + now = time.perf_counter() + async with progress_lock: + completed_chunks += 1 + completed_bytes += len(chunk_body) + should_log = ( + completed_chunks == total_chunks + or (now - last_progress_log) >= 5.0 + ) + if should_log: + elapsed = max(0.001, now - t0) + speed_bps = completed_bytes / elapsed + log.info( + "Parallel download progress: %s [%d/%d chunks]", + self._progress_line( + elapsed=elapsed, + done=completed_bytes, + total=total_bytes, + speed_bytes_per_sec=speed_bps, + ), + completed_chunks, total_chunks, + ) + last_progress_log = now + return chunk_body + last_err = err except Exception as e_: last_err = repr(e_) log.warning("Range %d-%d retry %d/%d: %s", @@ -837,8 +1169,15 @@ class DomainFronter: full_body = b"".join(parts) kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0 - log.info("Parallel download complete: %d B in %.2fs = %.1f KB/s", - len(full_body), elapsed, kbs) + log.info( + "Parallel download complete: %s", + self._progress_line( + elapsed=elapsed, + done=len(full_body), + total=len(full_body), + speed_bytes_per_sec=kbs * 1024, + ), + ) # Return as 200 OK (client sent a normal GET) result = f"HTTP/1.1 200 OK\r\n" @@ -851,6 +1190,219 @@ class DomainFronter: result += "\r\n" return result.encode() + full_body + async def stream_parallel_download(self, url: str, headers: dict, + writer, + *, + chunk_size: int = 512 * 1024, + max_parallel: int = 8, + max_chunks: int = 256, + min_size: int = 0) -> bool: + """Stream a large range-capable download to the client incrementally. + + Returns False when the target should fall back to the normal relay + path (for example no range support or the file is too small). + Returns True once this method has taken ownership of the client + response, even if the stream later aborts. + """ + first_resp = await self._range_probe(url, headers, 0, chunk_size - 1) + + status, resp_hdrs, resp_body = self._split_raw_response(first_resp) + if status != 206: + log.info( + "Streaming download fallback: initial probe returned %s for %s", + status, url[:80], + ) + return False + + parsed_range = self._parse_content_range(resp_hdrs.get("content-range", "")) + if not parsed_range: + log.info( + "Streaming download fallback: missing/invalid Content-Range for %s", + url[:80], + ) + return False + first_start, first_end, total_size = parsed_range + first_err = self._validate_range_response( + status, resp_hdrs, resp_body, first_start, first_end, total_size, + ) + if first_start != 0 or first_err: + log.info( + "Streaming download fallback: invalid first range (%s) for %s", + first_err or f"start={first_start}", + url[:80], + ) + return False + if min_size > 0 and total_size < min_size: + log.info( + "Streaming download fallback: file too small (%d < %d) for %s", + total_size, min_size, url[:80], + ) + return False + if max_chunks > 0: + required_chunk_size = max( + chunk_size, + (total_size + max_chunks - 1) // max_chunks, + ) + if required_chunk_size != chunk_size: + log.info( + "Parallel download tuning: chunk size raised from %d KB to %d KB " + "to keep request count under %d", + chunk_size // 1024, + required_chunk_size // 1024, + max_chunks, + ) + chunk_size = required_chunk_size + + if total_size <= chunk_size or len(resp_body) >= total_size: + writer.write(self._render_streaming_headers(resp_hdrs, total_size)) + writer.write(resp_body) + await writer.drain() + return True + + ranges = [] + start = len(resp_body) + while start < total_size: + end = min(start + chunk_size - 1, total_size - 1) + ranges.append((start, end)) + start = end + 1 + + log.info("Parallel streaming download: %d bytes, %d chunks of %d KB", + total_size, len(ranges) + 1, chunk_size // 1024) + + temp_file = tempfile.TemporaryFile(prefix="mhrvpn_dl_") + file_lock = asyncio.Lock() + sem = asyncio.Semaphore(max_parallel) + cancel_event = asyncio.Event() + tasks: list[asyncio.Task] = [] + ready = [asyncio.Event() for _ in ranges] + errors: list[Exception | None] = [None for _ in ranges] + delivered_chunks = 1 + delivered_bytes = len(resp_body) + total_chunks = len(ranges) + 1 + last_progress_log = time.perf_counter() + t0 = time.perf_counter() + + async def _write_progress(force: bool = False) -> None: + nonlocal last_progress_log + now = time.perf_counter() + if not force and (now - last_progress_log) < 5.0: + return + elapsed = max(0.001, now - t0) + speed_bps = delivered_bytes / elapsed + log.info( + "Parallel download progress: %s [%d/%d chunks]", + self._progress_line( + elapsed=elapsed, + done=delivered_bytes, + total=total_size, + speed_bytes_per_sec=speed_bps, + ), + delivered_chunks, total_chunks, + ) + last_progress_log = now + + async def fetch_range(index: int, start_off: int, end_off: int, + max_tries: int = 3) -> None: + async with sem: + base_headers = dict(headers) if headers else {} + base_headers["Range"] = f"bytes={start_off}-{end_off}" + payload = self._build_payload("GET", url, base_headers, b"") + expected = end_off - start_off + 1 + last_err = "unknown" + try: + for attempt in range(max_tries): + if cancel_event.is_set(): + return + try: + raw = await self._relay_payload_h1(payload) + chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw) + err = self._validate_range_response( + chunk_status, chunk_headers, chunk_body, + start_off, end_off, total_size, + ) + if err is None: + async with file_lock: + await asyncio.to_thread( + self._spool_write, temp_file, start_off, chunk_body, + ) + ready[index].set() + return + last_err = err + except Exception as exc: + last_err = repr(exc) + if cancel_event.is_set(): + return + log.warning("Range %d-%d retry %d/%d: %s", + start_off, end_off, attempt + 1, max_tries, last_err) + await asyncio.sleep(0.3 * (attempt + 1)) + errors[index] = RuntimeError( + f"chunk {start_off}-{end_off} failed after {max_tries} tries: {last_err}" + ) + ready[index].set() + except asyncio.CancelledError: + raise + + try: + writer.write(self._render_streaming_headers(resp_hdrs, total_size)) + writer.write(resp_body) + await writer.drain() + + for index, (start_off, end_off) in enumerate(ranges): + tasks.append(asyncio.create_task(fetch_range(index, start_off, end_off))) + + for index, (start_off, end_off) in enumerate(ranges): + await ready[index].wait() + if errors[index] is not None: + raise errors[index] + expected = end_off - start_off + 1 + async with file_lock: + chunk = await asyncio.to_thread( + self._spool_read, temp_file, start_off, expected, + ) + if len(chunk) != expected: + raise RuntimeError( + f"spooled chunk {start_off}-{end_off} was truncated " + f"({len(chunk)}/{expected} B)" + ) + writer.write(chunk) + await writer.drain() + delivered_chunks += 1 + delivered_bytes += len(chunk) + await _write_progress(force=(index == len(ranges) - 1)) + + elapsed = max(0.001, time.perf_counter() - t0) + log.info( + "Parallel streaming download complete: %s", + self._progress_line( + elapsed=elapsed, + done=total_size, + total=total_size, + speed_bytes_per_sec=total_size / elapsed, + ), + ) + return True + except (ConnectionError, BrokenPipeError, TimeoutError) as exc: + log.info("Parallel download cancelled by client: %s", exc) + cancel_event.set() + return True + except Exception as exc: + self._mark_stream_download_failure(url, str(exc)) + log.error("Parallel streaming download failed (%s): %s", url[:60], exc) + cancel_event.set() + try: + if not writer.is_closing(): + writer.close() + except Exception: + pass + return True + finally: + cancel_event.set() + for task in tasks: + task.cancel() + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + temp_file.close() + @staticmethod def _rewrite_206_to_200(raw: bytes) -> bytes: """Rewrite a 206 Partial Content response to 200 OK. @@ -1039,33 +1591,43 @@ class DomainFronter: async def _relay_with_retry(self, payload: dict) -> bytes: """Single relay with one retry on failure. Uses H2 if available.""" + attempts = self._retry_attempts_for_payload(payload) # Fan-out: race N Apps Script instances when enabled and H2 is up. # Cuts tail latency when one container is slow/cold. Only kicks in # if multiple script IDs are configured and the H2 transport is live. - if (self._parallel_relay > 1 + if (attempts > 1 + and self._parallel_relay > 1 and len(self._script_ids) > 1 - and self._h2 and self._h2.is_connected): + and self._h2_available()): try: - return await asyncio.wait_for( - self._relay_fanout(payload), timeout=RELAY_TIMEOUT, + result = await asyncio.wait_for( + self._relay_fanout(payload), timeout=self._relay_timeout, ) + self._record_h2_success() + return result except Exception as e: + self._record_h2_failure(e) log.debug("Fan-out relay failed (%s), falling back", e) # fall through to single-path logic below # Try HTTP/2 first — much faster (multiplexed, no pool checkout) - if self._h2 and self._h2.is_connected: - for attempt in range(2): + if self._h2_available(): + for attempt in range(attempts): try: - return await asyncio.wait_for( - self._relay_single_h2(payload), timeout=RELAY_TIMEOUT + result = await asyncio.wait_for( + self._relay_single_h2(payload), timeout=self._relay_timeout ) + self._record_h2_success() + return result except Exception as e: - if attempt == 0: + self._record_h2_failure(e) + if attempt < attempts - 1: log.debug("H2 relay failed (%s), reconnecting", e) try: await self._h2.reconnect() - except Exception: + self._record_h2_success() + except Exception as reconnect_exc: + self._record_h2_failure(reconnect_exc) log.warning("H2 reconnect failed, falling back to H1") break else: @@ -1073,14 +1635,15 @@ class DomainFronter: # HTTP/1.1 fallback (pool-based) async with self._semaphore: - for attempt in range(2): + for attempt in range(attempts): try: return await asyncio.wait_for( - self._relay_single(payload), timeout=RELAY_TIMEOUT + self._relay_single(payload), timeout=self._relay_timeout ) except Exception as e: - if attempt == 0: - log.debug("Relay attempt 1 failed (%s: %s), retrying", + if attempt < attempts - 1: + log.debug("Relay attempt %d failed (%s: %s), retrying", + attempt + 1, type(e).__name__, e) await self._flush_pool() else: @@ -1248,7 +1811,7 @@ class DomainFronter: path = self._exec_path(payloads[0].get("u") if payloads else None) # Try HTTP/2 first - if self._h2 and self._h2.is_connected: + if self._h2_available(): try: status, headers, body = await asyncio.wait_for( self._h2.request( @@ -1258,8 +1821,10 @@ class DomainFronter: ), timeout=30, ) + self._record_h2_success() return self._parse_batch_body(body, payloads) except Exception as e: + self._record_h2_failure(e) log.debug("H2 batch failed (%s), falling back to H1", e) # HTTP/1.1 fallback @@ -1383,7 +1948,13 @@ class DomainFronter: if "chunked" in transfer_encoding: body = await self._read_chunked(reader, body) elif content_length: - remaining = int(content_length) - len(body) + total = int(content_length) + if total > self._max_response_body_bytes: + raise RuntimeError( + "Relay response exceeds configured size cap " + f"({total} > {self._max_response_body_bytes} bytes)" + ) + remaining = total - len(body) while remaining > 0: chunk = await asyncio.wait_for( reader.read(min(remaining, 65536)), timeout=20 @@ -1391,6 +1962,10 @@ class DomainFronter: if not chunk: break body += chunk + if len(body) > self._max_response_body_bytes: + raise RuntimeError( + "Relay response exceeded configured size cap while reading body" + ) remaining -= len(chunk) else: # No framing — short timeout read (keep-alive safe) @@ -1400,6 +1975,10 @@ class DomainFronter: if not chunk: break body += chunk + if len(body) > self._max_response_body_bytes: + raise RuntimeError( + "Relay response exceeded configured size cap while streaming" + ) except asyncio.TimeoutError: break @@ -1407,13 +1986,17 @@ class DomainFronter: enc = headers.get("content-encoding", "") if enc: body = codec.decode(body, enc) + if len(body) > self._max_response_body_bytes: + raise RuntimeError( + "Decoded relay response exceeded configured size cap" + ) return status, headers, body async def _read_chunked(self, reader, buf=b""): """Incrementally read chunked transfer-encoding.""" result = b"" - _MAX_BODY = 200 * 1024 * 1024 # 200 MB total body cap + max_body = self._max_response_body_bytes while True: while b"\r\n" not in buf: data = await asyncio.wait_for(reader.read(8192), timeout=20) @@ -1433,9 +2016,11 @@ class DomainFronter: break if size == 0: break - if size > _MAX_BODY or len(result) + size > _MAX_BODY: - log.warning("Chunked body exceeds %d MB cap — truncating", _MAX_BODY // (1024 * 1024)) - break + if size > max_body or len(result) + size > max_body: + raise RuntimeError( + "Chunked relay response exceeded configured size cap " + f"({max_body} bytes)" + ) while len(buf) < size + 2: data = await asyncio.wait_for(reader.read(65536), timeout=20) @@ -1479,6 +2064,13 @@ class DomainFronter: status = data.get("s", 200) resp_headers = data.get("h", {}) resp_body = base64.b64decode(data.get("b", "")) + if len(resp_body) > self._max_response_body_bytes: + return self._error_response( + 502, + "Relay response exceeds cap " + f"({self._max_response_body_bytes} bytes). " + "Increase max_response_body_bytes if your system has enough RAM.", + ) status_text = {200: "OK", 206: "Partial Content", 301: "Moved", 302: "Found", 304: "Not Modified", diff --git a/src/h2_transport.py b/src/h2_transport.py index 6eb3d85..edeb1d1 100644 --- a/src/h2_transport.py +++ b/src/h2_transport.py @@ -80,6 +80,7 @@ class H2Transport: self._write_lock = asyncio.Lock() self._connect_lock = asyncio.Lock() self._read_task: asyncio.Task | None = None + self._conn_generation = 0 # Per-stream tracking self._streams: dict[int, _StreamState] = {} @@ -174,26 +175,34 @@ class H2Transport: await self._flush() self._connected = True - self._read_task = asyncio.create_task(self._reader_loop()) + self._conn_generation += 1 + generation = self._conn_generation + self._read_task = asyncio.create_task(self._reader_loop(generation)) log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)", self.connect_host, sni) async def reconnect(self): """Close current connection and re-establish.""" - await self._close_internal() - await self._do_connect() + async with self._connect_lock: + await self._close_internal() + await self._do_connect() async def _close_internal(self): self._connected = False - if self._read_task: - self._read_task.cancel() - self._read_task = None + read_task = self._read_task + self._read_task = None + if read_task: + read_task.cancel() + await asyncio.gather(read_task, return_exceptions=True) if self._writer: try: - self._writer.close() + writer = self._writer + self._writer = None + writer.close() + await writer.wait_closed() except Exception: pass - self._writer = None + self._reader = None # Wake all pending streams so they can raise for state in self._streams.values(): state.error = "Connection closed" @@ -327,7 +336,7 @@ class H2Transport: # ── Background reader ───────────────────────────────────────── - async def _reader_loop(self): + async def _reader_loop(self, generation: int): """Background: read H2 frames, dispatch events to waiting streams.""" try: while self._connected: @@ -352,14 +361,20 @@ class H2Transport: except asyncio.CancelledError: pass except Exception as e: - log.error("H2 reader error: %s", e) + if "application data after close notify" in str(e).lower(): + log.debug("H2 reader closed after close_notify: %s", e) + else: + log.error("H2 reader error: %s", e) finally: - self._connected = False - for state in self._streams.values(): - if not state.done.is_set(): - state.error = "Connection lost" - state.done.set() - log.info("H2 reader loop ended") + if generation != self._conn_generation: + log.debug("H2 reader loop ended for stale generation %d", generation) + else: + self._connected = False + for state in self._streams.values(): + if not state.done.is_set(): + state.error = "Connection lost" + state.done.set() + log.info("H2 reader loop ended") def _dispatch(self, event): """Route a single h2 event to its stream.""" diff --git a/src/proxy_server.py b/src/proxy_server.py index 93416eb..8d3f6d3 100644 --- a/src/proxy_server.py +++ b/src/proxy_server.py @@ -65,6 +65,23 @@ def _parse_content_length(header_block: bytes) -> int: return 0 +def _has_unsupported_transfer_encoding(header_block: bytes) -> bool: + """True when the request uses Transfer-Encoding, which we don't stream.""" + for raw_line in header_block.split(b"\r\n"): + name, sep, value = raw_line.partition(b":") + if not sep: + continue + if name.strip().lower() != b"transfer-encoding": + continue + encodings = [ + token.strip().lower() + for token in value.decode(errors="replace").split(",") + if token.strip() + ] + return any(token != "identity" for token in encodings) + return False + + class ResponseCache: """Simple LRU response cache — avoids repeated relay calls.""" @@ -147,6 +164,14 @@ class ProxyServer: _GOOGLE_DIRECT_ALLOW_EXACT = GOOGLE_DIRECT_ALLOW_EXACT _GOOGLE_DIRECT_ALLOW_SUFFIXES = GOOGLE_DIRECT_ALLOW_SUFFIXES _TRACE_HOST_SUFFIXES = TRACE_HOST_SUFFIXES + _DOWNLOAD_DEFAULT_EXTS = tuple(sorted(LARGE_FILE_EXTS)) + _DOWNLOAD_ACCEPT_MARKERS = ( + "application/octet-stream", + "application/zip", + "application/x-bittorrent", + "video/", + "audio/", + ) def __init__(self, config: dict): self.host = config.get("listen_host", "127.0.0.1") @@ -159,6 +184,30 @@ class ProxyServer: self._cache = ResponseCache(max_mb=CACHE_MAX_MB) self._direct_fail_until: dict[str, float] = {} self._servers: list[asyncio.base_events.Server] = [] + self._client_tasks: set[asyncio.Task] = set() + self._tcp_connect_timeout = self._cfg_float( + config, "tcp_connect_timeout", TCP_CONNECT_TIMEOUT, minimum=1.0, + ) + self._download_min_size = self._cfg_int( + config, "chunked_download_min_size", 5 * 1024 * 1024, minimum=0, + ) + self._download_chunk_size = self._cfg_int( + config, "chunked_download_chunk_size", 512 * 1024, minimum=64 * 1024, + ) + self._download_max_parallel = self._cfg_int( + config, "chunked_download_max_parallel", 8, minimum=1, + ) + self._download_max_chunks = self._cfg_int( + config, "chunked_download_max_chunks", 256, minimum=1, + ) + self._download_extensions, self._download_any_extension = ( + self._normalize_download_extensions( + config.get( + "chunked_download_extensions", + list(self._DOWNLOAD_DEFAULT_EXTS), + ) + ) + ) # hosts override — DNS fake-map: domain/suffix → IP # Checked before any real DNS lookup; supports exact and suffix matching. @@ -198,6 +247,55 @@ class ProxyServer: # ── Host-policy helpers ─────────────────────────────────────── + @staticmethod + def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int: + try: + value = int(config.get(key, default)) + except (TypeError, ValueError): + value = default + return max(minimum, value) + + @staticmethod + def _cfg_float(config: dict, key: str, default: float, + *, minimum: float = 0.1) -> float: + try: + value = float(config.get(key, default)) + except (TypeError, ValueError): + value = default + return max(minimum, value) + + @classmethod + def _normalize_download_extensions(cls, raw) -> tuple[tuple[str, ...], bool]: + values = raw if isinstance(raw, (list, tuple)) else cls._DOWNLOAD_DEFAULT_EXTS + normalized: list[str] = [] + any_extension = False + seen: set[str] = set() + for item in values: + ext = str(item).strip().lower() + if not ext: + continue + if ext in {"*", ".*"}: + any_extension = True + continue + if not ext.startswith("."): + ext = "." + ext + if ext not in seen: + seen.add(ext) + normalized.append(ext) + if not normalized and not any_extension: + normalized = list(cls._DOWNLOAD_DEFAULT_EXTS) + return tuple(normalized), any_extension + + def _track_current_task(self) -> asyncio.Task | None: + task = asyncio.current_task() + if task is not None: + self._client_tasks.add(task) + return task + + def _untrack_task(self, task: asyncio.Task | None) -> None: + if task is not None: + self._client_tasks.discard(task) + @staticmethod def _load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]: """Accept a list of host strings; return (exact_set, suffix_tuple). @@ -352,15 +450,18 @@ class ProxyServer: self.socks_host, self.socks_port, ) - async with http_srv: - if socks_srv: - async with socks_srv: - await asyncio.gather( - http_srv.serve_forever(), - socks_srv.serve_forever(), - ) - else: - await http_srv.serve_forever() + try: + async with http_srv: + if socks_srv: + async with socks_srv: + await asyncio.gather( + http_srv.serve_forever(), + socks_srv.serve_forever(), + ) + else: + await http_srv.serve_forever() + except asyncio.CancelledError: + raise async def stop(self): """Shut down all listeners and release relay resources.""" @@ -375,6 +476,15 @@ class ProxyServer: except Exception: pass self._servers = [] + + current = asyncio.current_task() + client_tasks = [task for task in self._client_tasks if task is not current] + for task in client_tasks: + task.cancel() + if client_tasks: + await asyncio.gather(*client_tasks, return_exceptions=True) + self._client_tasks.clear() + try: await self.fronter.close() except Exception as exc: @@ -384,6 +494,7 @@ class ProxyServer: async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): addr = writer.get_extra_info("peername") + task = self._track_current_task() try: first_line = await asyncio.wait_for(reader.readline(), timeout=30) if not first_line: @@ -400,6 +511,16 @@ class ProxyServer: if line in (b"\r\n", b"\n", b""): break + if _has_unsupported_transfer_encoding(header_block): + log.warning("Unsupported Transfer-Encoding on client request") + writer.write( + b"HTTP/1.1 501 Not Implemented\r\n" + b"Connection: close\r\n" + b"Content-Length: 0\r\n\r\n" + ) + await writer.drain() + return + request_line = first_line.decode(errors="replace").strip() parts = request_line.split(" ", 2) if len(parts) < 2: @@ -412,11 +533,14 @@ class ProxyServer: else: await self._do_http(header_block, reader, writer) + except asyncio.CancelledError: + pass except asyncio.TimeoutError: log.debug("Timeout: %s", addr) except Exception as e: log.error("Error (%s): %s", addr, e) finally: + self._untrack_task(task) try: writer.close() await writer.wait_closed() @@ -426,6 +550,7 @@ class ProxyServer: async def _on_socks_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): addr = writer.get_extra_info("peername") + task = self._track_current_task() try: header = await asyncio.wait_for(reader.readexactly(2), timeout=15) ver, nmethods = header[0], header[1] @@ -475,11 +600,14 @@ class ProxyServer: except asyncio.IncompleteReadError: pass + except asyncio.CancelledError: + pass except asyncio.TimeoutError: log.debug("SOCKS5 timeout: %s", addr) except Exception as e: log.error("SOCKS5 error (%s): %s", addr, e) finally: + self._untrack_task(task) try: writer.close() await writer.wait_closed() @@ -783,7 +911,9 @@ class ProxyServer: """ target_ip = connect_ip or host try: - r_remote, w_remote = await self._open_tcp_connection(target_ip, port, timeout=10) + r_remote, w_remote = await self._open_tcp_connection( + target_ip, port, timeout=self._tcp_connect_timeout, + ) except Exception as e: log.error("Direct tunnel connect failed (%s via %s): %s", host, target_ip, e) @@ -858,7 +988,7 @@ class ProxyServer: ssl=ssl_ctx_client, server_hostname=sni_out, ), - timeout=10, + timeout=self._tcp_connect_timeout, ) except Exception as e: log.error("SNI-rewrite outbound connect failed (%s via %s): %s", @@ -969,6 +1099,15 @@ class ProxyServer: # Read body body = b"" + if _has_unsupported_transfer_encoding(header_block): + log.warning("Unsupported Transfer-Encoding → %s:%d", host, port) + writer.write( + b"HTTP/1.1 501 Not Implemented\r\n" + b"Connection: close\r\n" + b"Content-Length: 0\r\n\r\n" + ) + await writer.drain() + break length = _parse_content_length(header_block) if length > MAX_REQUEST_BODY_BYTES: raise ValueError(f"Request body too large: {length} bytes") @@ -1008,25 +1147,7 @@ class ProxyServer: log.info("MITM → %s %s", method, url) - # ── CORS: extract relevant request headers ──────────────────── - origin = next( - (v for k, v in headers.items() if k.lower() == "origin"), "" - ) - acr_method = next( - (v for k, v in headers.items() - if k.lower() == "access-control-request-method"), "" - ) - acr_headers = next( - (v for k, v in headers.items() - if k.lower() == "access-control-request-headers"), "" - ) - - # CORS preflight — respond directly; UrlFetchApp doesn't - # support OPTIONS so forwarding it would always fail. - if method.upper() == "OPTIONS" and acr_method: - log.debug("CORS preflight → %s (responding locally)", url[:60]) - writer.write(self._cors_preflight_response(origin, acr_method, acr_headers)) - await writer.drain() + if await self._maybe_stream_download(method, url, headers, body, writer): continue # Check local cache first (GET only) @@ -1057,11 +1178,6 @@ class ProxyServer: self._cache.put(url, response, ttl) log.debug("Cached (%ds): %s", ttl, url[:60]) - # Inject permissive CORS headers whenever the browser - # sent an Origin (cross-origin XHR / fetch). - if origin and response: - response = self._inject_cors_headers(response, origin) - self._log_response_summary(url, response) writer.write(response) @@ -1077,64 +1193,6 @@ class ProxyServer: log.error("MITM handler error (%s): %s", host, e) break - # ── CORS helpers ────────────────────────────────────────────────────────── - - @staticmethod - def _cors_preflight_response(origin: str, acr_method: str, acr_headers: str) -> bytes: - """Return a 204 No Content response that satisfies a CORS preflight.""" - allow_origin = origin or "*" - allow_methods = ( - f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS" - if acr_method else - "GET, POST, PUT, DELETE, PATCH, OPTIONS" - ) - allow_headers = acr_headers or "*" - return ( - "HTTP/1.1 204 No Content\r\n" - f"Access-Control-Allow-Origin: {allow_origin}\r\n" - f"Access-Control-Allow-Methods: {allow_methods}\r\n" - f"Access-Control-Allow-Headers: {allow_headers}\r\n" - "Access-Control-Allow-Credentials: true\r\n" - "Access-Control-Max-Age: 86400\r\n" - "Vary: Origin\r\n" - "Content-Length: 0\r\n" - "\r\n" - ).encode() - - @staticmethod - def _inject_cors_headers(response: bytes, origin: str) -> bytes: - """Inject CORS headers only if the upstream response lacks them. - - We must NOT overwrite the origin server's CORS headers: sites like - x.com return carefully-scoped Access-Control-Allow-Headers that list - specific custom headers (e.g. x-csrf-token). Replacing them with - wildcards together with Allow-Credentials: true makes browsers - reject the response (per the Fetch spec, "*" is literal when - credentials are included), which the site then blames on privacy - extensions. So we only fill in what the server omitted. - """ - sep = b"\r\n\r\n" - if sep not in response: - return response - header_section, body = response.split(sep, 1) - lines = header_section.decode(errors="replace").split("\r\n") - - existing = {ln.split(":", 1)[0].strip().lower() - for ln in lines if ":" in ln} - - # If the upstream already handled CORS, leave it completely alone. - if "access-control-allow-origin" in existing: - return response - - # Otherwise inject a minimal, credential-safe set (no wildcards, - # since wildcards combined with credentials are invalid). - allow_origin = origin or "*" - additions = [f"Access-Control-Allow-Origin: {allow_origin}"] - if allow_origin != "*": - additions.append("Access-Control-Allow-Credentials: true") - additions.append("Vary: Origin") - return ("\r\n".join(lines + additions) + "\r\n\r\n").encode() + body - async def _relay_smart(self, method, url, headers, body): """Choose optimal relay strategy based on request type. @@ -1156,22 +1214,67 @@ class ProxyServer: # Only probe with Range when the URL looks like a big file. if self._is_likely_download(url, headers): return await self.fronter.relay_parallel( - method, url, headers, body + method, + url, + headers, + body, + chunk_size=self._download_chunk_size, + max_parallel=self._download_max_parallel, + max_chunks=self._download_max_chunks, + min_size=self._download_min_size, ) return await self.fronter.relay(method, url, headers, body) def _is_likely_download(self, url: str, headers: dict) -> bool: """Heuristic: is this URL likely a large file download?""" path = url.split("?")[0].lower() - for ext in LARGE_FILE_EXTS: + if self._download_any_extension: + return True + for ext in self._download_extensions: if path.endswith(ext): return True + accept = self._header_value(headers, "accept").lower() + if any(marker in accept for marker in self._DOWNLOAD_ACCEPT_MARKERS): + return True return False + async def _maybe_stream_download(self, method: str, url: str, + headers: dict | None, body: bytes, + writer) -> bool: + if method.upper() != "GET" or body: + return False + if headers: + for key in headers: + if key.lower() == "range": + return False + effective_headers = headers or {} + if not self._is_likely_download(url, effective_headers): + return False + if not self.fronter.stream_download_allowed(url): + return False + return await self.fronter.stream_parallel_download( + url, + effective_headers, + writer, + chunk_size=self._download_chunk_size, + max_parallel=self._download_max_parallel, + max_chunks=self._download_max_chunks, + min_size=self._download_min_size, + ) + # ── Plain HTTP forwarding ───────────────────────────────────── async def _do_http(self, header_block: bytes, reader, writer): body = b"" + if _has_unsupported_transfer_encoding(header_block): + log.warning("Unsupported Transfer-Encoding on plain HTTP request") + writer.write( + b"HTTP/1.1 501 Not Implemented\r\n" + b"Connection: close\r\n" + b"Content-Length: 0\r\n\r\n" + ) + await writer.drain() + return length = _parse_content_length(header_block) if length > MAX_REQUEST_BODY_BYTES: writer.write(b"HTTP/1.1 413 Content Too Large\r\n\r\n") @@ -1194,22 +1297,7 @@ class ProxyServer: k, v = raw_line.decode(errors="replace").split(":", 1) headers[k.strip()] = v.strip() - # ── CORS preflight over plain HTTP ──────────────────────────── - origin = next( - (v for k, v in headers.items() if k.lower() == "origin"), "" - ) - acr_method = next( - (v for k, v in headers.items() - if k.lower() == "access-control-request-method"), "" - ) - acr_headers_val = next( - (v for k, v in headers.items() - if k.lower() == "access-control-request-headers"), "" - ) - if method.upper() == "OPTIONS" and acr_method: - log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60]) - writer.write(self._cors_preflight_response(origin, acr_method, acr_headers_val)) - await writer.drain() + if await self._maybe_stream_download(method, url, headers, body, writer): return # Cache check for GET @@ -1227,9 +1315,6 @@ class ProxyServer: if ttl > 0: self._cache.put(url, response, ttl) - # Inject CORS headers for cross-origin requests - if origin and response: - response = self._inject_cors_headers(response, origin) self._log_response_summary(url, response) writer.write(response)