Improve relay stability and add streamed parallel downloads

PK3NZO
2026-04-23 15:02:16 +03:30
parent 57738ec5c8
commit afdd3e1036
10 changed files with 1101 additions and 192 deletions
+140
@@ -0,0 +1,140 @@
# Suggested PR Title
Improve relay stability and add streamed parallel downloads for large files
## Summary
This PR focuses on making the Apps Script relay safer, more stable, and significantly more usable for large downloads.
The biggest change is a new streamed parallel download path for files that are likely to be large. Instead of buffering the entire file and handing it to the browser only at the end, the proxy now starts sending data to the client immediately while it fetches the remaining ranges in parallel. Users see real download progress in the browser, which fixes the old behavior where the browser appeared stuck loading until the whole file had finished downloading in the background.
Alongside that, this PR hardens request handling, improves shutdown behavior, adds safer defaults, reduces quota pressure for large downloads, and makes HTTP/2 fallback behavior much more defensive when the transport becomes unstable.
## What Changed
### Large-file downloads
- Added a streamed parallel download path for likely large downloads
- Started sending response headers/body to the client immediately instead of waiting for full buffering
- Added disk-backed spooling for ordered chunk delivery during streaming downloads
- Added per-range validation using `Content-Range` and expected body length checks
- Added initial range-probe retries before falling back
- Added host cooldown for flaky streaming download targets so repeatedly failing hosts do not keep producing broken partial downloads
- Moved range-download traffic off the shared H2 relay path and onto the H1 pool to avoid destabilizing normal relay traffic
- Added progress logging in a more readable terminal format:
  - elapsed time
  - progress bar
  - bytes downloaded / total
  - current download speed
- Added `.bin` to default large-download extension heuristics
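
The ordered delivery step can be pictured with a small sketch. This is not the code from this PR: `OrderedSpool`, `put`, and `drain` are hypothetical names, and the sketch assumes fixed-size chunks addressed by index. Ranges that finish out of order are written to a temporary file at their byte offset, and only the contiguous prefix is released to the client.

```python
import tempfile


class OrderedSpool:
    """Buffer out-of-order chunks on disk and release them in file order.

    Hypothetical helper for illustration; names and layout are assumptions.
    """

    def __init__(self, chunk_size: int, total_size: int):
        self.chunk_size = chunk_size
        self.total_size = total_size
        self._file = tempfile.TemporaryFile()
        self._ready = set()  # chunk indexes already written to disk
        self._next = 0       # index of the next chunk owed to the client

    def put(self, index: int, data: bytes) -> None:
        # Write the chunk at its final offset, whatever order it arrived in.
        self._file.seek(index * self.chunk_size)
        self._file.write(data)
        self._ready.add(index)

    def drain(self) -> list:
        """Return every chunk that is now contiguous with what was sent."""
        out = []
        while self._next in self._ready:
            offset = self._next * self.chunk_size
            size = min(self.chunk_size, self.total_size - offset)
            self._file.seek(offset)
            out.append(self._file.read(size))
            self._ready.discard(self._next)
            self._next += 1
        return out

    def close(self) -> None:
        self._file.close()
```

A caller would invoke `put` from each range worker and forward whatever `drain` returns to the client socket, so the browser only ever receives bytes in file order.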
### Stability and relay behavior
- Added configurable relay and connect timeouts
- Added configurable response body cap enforcement across buffered relay paths
- Made retry behavior safer by limiting retries for non-idempotent requests
- Improved GET request coalescing by including representation-relevant headers in the coalescing key
- Added bounded chunk/request tuning for large downloads to reduce quota pressure
- Added safer defaults for:
  - `chunked_download_chunk_size`
  - `chunked_download_max_parallel`
  - `chunked_download_max_chunks`
- Added explicit handling for unsupported request `Transfer-Encoding`
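
As a worked example of the coalescing change, the standalone function below mirrors the shape of `_coalesce_key` from this commit but uses only a subset of the header list. The name `coalesce_key` and the example URL are illustrative assumptions. It shows that two GETs for the same URL are merged only when their representation-relevant headers match.

```python
# Subset of the headers the commit lists in _COALESCE_VARY_HEADERS.
VARY_HEADERS = ("accept", "accept-language", "user-agent")


def coalesce_key(url, headers):
    """Build a dedup key from the URL plus selected request headers."""
    parts = [url]
    lowered = {str(k).lower(): str(v) for k, v in (headers or {}).items()}
    for name in VARY_HEADERS:
        value = lowered.get(name)
        if value:
            parts.append(f"{name}={value}")
    return "\n".join(parts)


url = "https://example.com/a"
html_key = coalesce_key(url, {"Accept": "text/html"})
json_key = coalesce_key(url, {"accept": "application/json"})
# Header-name case and unrelated headers do not change the key.
same_key = coalesce_key(url, {"ACCEPT": "text/html", "X-Trace": "1"})
```

Here `html_key` and `json_key` differ, so an HTML request no longer receives a JSON response that happened to be in flight for the same URL.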
### HTTP/2 hardening
- Fixed H2 reconnect/reader lifecycle issues that could cause reconnect storms or stale-reader interference
- Added temporary H2 cooldown/circuit-breaker behavior after repeated consecutive H2 failures
- Reduced noisy close-notify behavior and improved connection lifecycle handling
- Prevented large parallel downloads from destabilizing the shared H2 relay path used by normal traffic
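
The cooldown behavior can be sketched as a small circuit breaker. This is an illustration rather than the commit's code: `FailureBreaker` is a hypothetical class, and it uses an injectable monotonic clock where the commit itself calls `time.time()`. The threshold of 3 and cooldown of 60 seconds match the constants added in this commit.

```python
import time


class FailureBreaker:
    """Disable a transport for a cooldown after N consecutive failures."""

    def __init__(self, threshold=3, cooldown=60.0, clock=time.monotonic):
        self._threshold = threshold
        self._cooldown = cooldown
        self._clock = clock          # injectable so the logic is testable
        self._streak = 0
        self._disabled_until = 0.0

    def available(self) -> bool:
        return self._clock() >= self._disabled_until

    def record_success(self) -> None:
        # Any success breaks the run of consecutive failures.
        self._streak = 0

    def record_failure(self) -> None:
        self._streak += 1
        if self._streak >= self._threshold:
            self._disabled_until = self._clock() + self._cooldown
            self._streak = 0
```

While `available()` is false, callers would route requests through the H1 pool instead of attempting another H2 reconnect.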
### Security / defaults / config hygiene
- Removed permissive CORS behavior that effectively bypassed browser CORS protections
- Changed `lan_sharing` to be opt-in by default
- Updated setup flow so LAN sharing is explicitly prompted instead of silently inheriting insecure defaults
- Synced docs and config examples with actual runtime behavior
### Shutdown / cleanup
- Improved shutdown cleanup so active client tasks are tracked, cancelled, and awaited during stop
- Reduced `Task was destroyed but it is pending` noise on normal shutdown
## Config Additions / Changes
Added or documented the following config options:
- `relay_timeout`
- `tls_connect_timeout`
- `tcp_connect_timeout`
- `max_response_body_bytes`
- `chunked_download_extensions`
- `chunked_download_min_size`
- `chunked_download_chunk_size`
- `chunked_download_max_parallel`
- `chunked_download_max_chunks`
Default download tuning was also made more conservative to improve stability.
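
To illustrate how `chunked_download_chunk_size` and `chunked_download_max_chunks` might interact, here is a sketch of one possible sizing rule. The exact rounding used by the relay is not shown in this excerpt, so `plan_chunks` is a hypothetical helper and its formula is an assumption: if the default chunk size would need more than `max_chunks` requests, the chunk size is raised just enough to fit.

```python
def plan_chunks(total_size, chunk_size=512 * 1024, max_chunks=256):
    """Return (effective_chunk_size, chunk_count) for a download.

    Assumed rule: grow the chunk size when the chunk count would exceed
    max_chunks. Integer ceiling division avoids float rounding.
    """
    if total_size <= 0:
        return chunk_size, 0
    count = -(-total_size // chunk_size)
    if count > max_chunks:
        chunk_size = -(-total_size // max_chunks)
        count = -(-total_size // chunk_size)
    return chunk_size, count
```

With the defaults, a 100 MiB file needs 200 chunks of 512 KiB and is left alone, while a 1 GiB file would need 2048 chunks, so under this rule the chunk size grows to 4 MiB and the count drops to 256.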
## Bugs Fixed
- Large downloads appearing only after full buffering instead of progressing in-browser
- Browser download UX looking stalled while the proxy was still downloading in the background
- Range chunks being accepted without validating `Content-Range`, which frequently produced corrupted downloads
- H2 reconnect/reader race behavior causing repeated `H2 reader loop ended` / `Connection lost` cascades
- Insecure default LAN exposure
- Dangerous CORS response injection behavior
- Inconsistent response size-cap enforcement
- Pending asyncio task noise during shutdown
- Silent mishandling of unsupported `Transfer-Encoding` requests
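
The range-validation fix can be made concrete with a simplified standalone check. It follows the same order of checks as `_validate_range_response` in this commit but omits the total-size comparison. `check_range` is a hypothetical name used only for this example.

```python
import re

_CONTENT_RANGE = re.compile(r"^\s*bytes\s+(\d+)-(\d+)/(\d+)\s*$")


def check_range(status, content_range, body, start, end):
    """Return None when a chunk is acceptable, else a short reason string."""
    if status != 206:
        return f"status {status}"
    match = _CONTENT_RANGE.match(content_range or "")
    if not match:
        return "missing/invalid Content-Range"
    got_start, got_end, _total = (int(g) for g in match.groups())
    if (got_start, got_end) != (start, end):
        return f"Content-Range mismatch {got_start}-{got_end}"
    expected = end - start + 1
    if len(body) != expected:
        return f"short chunk {len(body)}/{expected} B"
    return None
```

A chunk that answers with `200`, with a different byte range, or with fewer bytes than requested is rejected instead of being written into the output at the wrong offset.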
## Manual Testing
### Download tests
- `https://ash-speed.hetzner.com/100MB.bin`
  - Download completed successfully
  - Browser showed progressive download behavior
  - Terminal progress output updated correctly
- `https://fsn1-speed.hetzner.com/100MB.bin`
  - Exposed host-specific flakiness during range streaming
  - Added host cooldown / fallback protection after repeated streaming failures
- `https://fsn1-speed.hetzner.com/1GB.bin`
  - Download completed successfully
  - Completed in under 5 minutes in real-world testing
  - Browser showed real incremental progress instead of waiting for full buffering
### Runtime / startup / shutdown checks
- Verified startup with the current config
- Verified graceful `Ctrl+C` shutdown behavior after cleanup changes
- Ran compile validation:
  - `python3 -m compileall main.py src setup.py`
## Result / Impact
This PR substantially improves the user experience and resilience of the proxy:
- Large downloads now behave like real downloads in the browser
- 100MB and 1GB files can complete successfully through the relay
- The system no longer feels stuck during large transfers
- Large-download traffic is isolated from the shared H2 relay path, reducing collateral instability
- Repeated H2 failures now degrade more safely instead of spamming reconnect errors
- Defaults are safer, especially for LAN exposure and browser security behavior
- Shutdown is cleaner and less noisy
## Notes
- Some hosts are still inherently flakier than others for parallel range downloading through Apps Script, so the host cooldown/fallback behavior is intentional and defensive
- This PR prioritizes correctness, safer degradation, and usable large-file behavior over forcing aggressive parallelism on every host
## Screenshot
Use the attached screenshot in the PR to show both browser-side progress and terminal-side progress:
- Local file: `/Users/pouriarc/Downloads/photo_2026-04-23 14.55.20.jpeg`
Recommended caption:
> 1GB download progressing normally in the browser while the relay reports live chunk progress and throughput in the terminal.
+10 -1
@@ -261,8 +261,17 @@ This project focuses entirely on the **Apps Script** relay — a free Google acc
|---------|---------|-------------| |---------|---------|-------------|
| `google_ip` | `216.239.38.120` | Google IP address to connect through | | `google_ip` | `216.239.38.120` | Google IP address to connect through |
| `front_domain` | `www.google.com` | Domain shown to the firewall/filter | | `front_domain` | `www.google.com` | Domain shown to the firewall/filter |
| `verify_ssl` | `true` | Verify TLS certificates | | `verify_ssl` | `true` | Verify the TLS certificate on the local fronted connection to Google/CDN |
| `relay_timeout` | `25` | Total timeout for one relayed request before it fails |
| `tls_connect_timeout` | `15` | Timeout for the proxy's TLS connection to the fronted Google/CDN endpoint |
| `tcp_connect_timeout` | `10` | Timeout for direct TCP tunnels and outbound SNI-rewrite connects |
| `max_response_body_bytes` | `209715200` | Hard cap for a single relay response body after buffering/decoding |
| `script_ids` | — | Multiple Script IDs for load balancing (array) | | `script_ids` | — | Multiple Script IDs for load balancing (array) |
| `chunked_download_extensions` | see [config.example.json](config.example.json) | File extensions that should use parallel range downloading. Supports `".*"` to probe all GET downloads. |
| `chunked_download_min_size` | `5242880` | Minimum total file size (5 MB) before range-parallel download stays enabled |
| `chunked_download_chunk_size` | `524288` | Per-range chunk size used by parallel downloads |
| `chunked_download_max_parallel` | `8` | Maximum simultaneous range requests for one download |
| `chunked_download_max_chunks` | `256` | Soft upper bound for total chunk requests; chunk size is raised automatically for very large files |
| `block_hosts` | `[]` | Hosts that must never be tunneled (return HTTP 403). Supports exact names (`ads.example.com`) or leading-dot suffixes (`.doubleclick.net`). | | `block_hosts` | `[]` | Hosts that must never be tunneled (return HTTP 403). Supports exact names (`ads.example.com`) or leading-dot suffixes (`.doubleclick.net`). |
| `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM, no relay). Useful for LAN resources or sites that break under MITM. | | `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM, no relay). Useful for LAN resources or sites that break under MITM. |
| `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the fast direct tunnel. | | `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the fast direct tunnel. |
+10 -1
@@ -210,8 +210,17 @@ json
|------|---------------|-------| |------|---------------|-------|
| `google_ip` | `216.239.38.120` | IP used for the Google route | | `google_ip` | `216.239.38.120` | IP used for the Google route |
| `front_domain` | `www.google.com` | Domain that the filter sees | | `front_domain` | `www.google.com` | Domain that the filter sees |
| `verify_ssl` | `true` | Verify TLS validity | | `verify_ssl` | `true` | Verify TLS validity only for the local fronted connection to Google/CDN |
| `relay_timeout` | `25` | Total timeout for each relay request before it fails |
| `tls_connect_timeout` | `15` | Timeout for the proxy's TLS connection to the fronted endpoint on Google/CDN |
| `tcp_connect_timeout` | `10` | Connect timeout for direct tunnels and SNI-rewrite |
| `max_response_body_bytes` | `209715200` | Hard cap on the body size of each relay response after buffering/decoding |
| `script_ids` | - | Multiple Deployment IDs for load balancing | | `script_ids` | - | Multiple Deployment IDs for load balancing |
| `chunked_download_extensions` | see [config.example.json](config.example.json) | File extensions that should use range-parallel downloading. `".*"` is also supported to probe all GET downloads. |
| `chunked_download_min_size` | `5242880` | Minimum total file size (5 MB) for parallel download to stay enabled |
| `chunked_download_chunk_size` | `524288` | Size of each chunk in a parallel download |
| `chunked_download_max_parallel` | `8` | Maximum number of simultaneous range requests for one download |
| `chunked_download_max_chunks` | `256` | Soft cap on the total number of chunk requests; for very large files the chunk size is raised automatically |
| `block_hosts` | `[]` | Hosts that must never be tunneled (403 response). Exact name (`ads.example.com`) or leading-dot suffix (`.doubleclick.net`). | | `block_hosts` | `[]` | Hosts that must never be tunneled (403 response). Exact name (`ads.example.com`) or leading-dot suffix (`.doubleclick.net`). |
| `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM and no relay). For internal network resources or sites that have problems with MITM. | | `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM and no relay). For internal network resources or sites that have problems with MITM. |
| `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the direct tunnel. | | `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the direct tunnel. |
+42 -1
@@ -10,8 +10,49 @@
"socks5_port": 1080, "socks5_port": 1080,
"log_level": "INFO", "log_level": "INFO",
"verify_ssl": true, "verify_ssl": true,
"lan_sharing": true, "lan_sharing": false,
"relay_timeout": 25,
"tls_connect_timeout": 15,
"tcp_connect_timeout": 10,
"max_response_body_bytes": 209715200,
"parallel_relay": 1, "parallel_relay": 1,
"chunked_download_extensions": [
".bin",
".zip",
".tar",
".gz",
".bz2",
".xz",
".7z",
".rar",
".exe",
".msi",
".dmg",
".deb",
".rpm",
".apk",
".iso",
".img",
".mp4",
".mkv",
".avi",
".mov",
".webm",
".mp3",
".flac",
".wav",
".aac",
".pdf",
".doc",
".docx",
".ppt",
".pptx",
".wasm"
],
"chunked_download_min_size": 5242880,
"chunked_download_chunk_size": 524288,
"chunked_download_max_parallel": 8,
"chunked_download_max_chunks": 256,
"block_hosts": [], "block_hosts": [],
"bypass_hosts": [ "bypass_hosts": [
"localhost", "localhost",
+1 -1
@@ -230,7 +230,7 @@ def main():
log.info("MITM CA is already trusted.") log.info("MITM CA is already trusted.")
# ── LAN sharing configuration ──────────────────────────────────────── # ── LAN sharing configuration ────────────────────────────────────────
lan_sharing = config.get("lan_sharing", True) lan_sharing = config.get("lan_sharing", False)
if lan_sharing: if lan_sharing:
# If LAN sharing is enabled and host is still localhost, change to all interfaces # If LAN sharing is enabled and host is still localhost, change to all interfaces
if config.get("listen_host", "127.0.0.1") == "127.0.0.1": if config.get("listen_host", "127.0.0.1") == "127.0.0.1":
+18 -1
@@ -86,6 +86,15 @@ def load_base_config() -> dict:
"socks5_port": 1080, "socks5_port": 1080,
"log_level": "INFO", "log_level": "INFO",
"verify_ssl": True, "verify_ssl": True,
"lan_sharing": False,
"relay_timeout": 25,
"tls_connect_timeout": 15,
"tcp_connect_timeout": 10,
"max_response_body_bytes": 200 * 1024 * 1024,
"chunked_download_min_size": 5 * 1024 * 1024,
"chunked_download_chunk_size": 512 * 1024,
"chunked_download_max_parallel": 8,
"chunked_download_max_chunks": 256,
"hosts": {}, "hosts": {},
} }
@@ -118,7 +127,15 @@ def configure_apps_script(cfg: dict) -> dict:
def configure_network(cfg: dict) -> dict: def configure_network(cfg: dict) -> dict:
print() print()
print(bold("Network settings") + dim(" (press enter to accept defaults)")) print(bold("Network settings") + dim(" (press enter to accept defaults)"))
cfg["listen_host"] = prompt("Listen host", default=str(cfg.get("listen_host", "127.0.0.1"))) cfg["lan_sharing"] = prompt_yes_no(
"Enable LAN sharing?",
default=bool(cfg.get("lan_sharing", False)),
)
default_host = str(cfg.get("listen_host", "127.0.0.1"))
if cfg["lan_sharing"] and default_host == "127.0.0.1":
default_host = "0.0.0.0"
cfg["listen_host"] = prompt("Listen host", default=default_host)
port = prompt("HTTP proxy port", default=str(cfg.get("listen_port", 8085))) port = prompt("HTTP proxy port", default=str(cfg.get("listen_port", 8085)))
try: try:
+1
@@ -165,6 +165,7 @@ STATIC_EXTS: tuple[str, ...] = (
".mp3", ".mp4", ".webm", ".wasm", ".avif", ".mp3", ".mp4", ".webm", ".wasm", ".avif",
) )
LARGE_FILE_EXTS = frozenset({ LARGE_FILE_EXTS = frozenset({
".bin",
".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar", ".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar",
".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk", ".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk",
".iso", ".img", ".iso", ".img",
+649 -57
@@ -16,6 +16,7 @@ import logging
import re import re
import socket import socket
import ssl import ssl
import tempfile
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from urllib.parse import urlparse from urllib.parse import urlparse
@@ -27,6 +28,7 @@ from constants import (
BATCH_WINDOW_MICRO, BATCH_WINDOW_MICRO,
CONN_TTL, CONN_TTL,
FRONT_SNI_POOL_GOOGLE, FRONT_SNI_POOL_GOOGLE,
MAX_RESPONSE_BODY_BYTES,
POOL_MAX, POOL_MAX,
POOL_MIN_IDLE, POOL_MIN_IDLE,
RELAY_TIMEOUT, RELAY_TIMEOUT,
@@ -83,6 +85,18 @@ def _build_sni_pool(front_domain: str, overrides: list | None) -> list[str]:
class DomainFronter: class DomainFronter:
_STATIC_EXTS = STATIC_EXTS _STATIC_EXTS = STATIC_EXTS
    _H2_FAILURE_COOLDOWN = 60.0
    _H2_FAILURE_THRESHOLD = 3
    _DOWNLOAD_STREAM_COOLDOWN = 300.0
    _COALESCE_VARY_HEADERS = (
        "accept",
        "accept-language",
        "user-agent",
        "sec-fetch-dest",
        "sec-fetch-mode",
        "sec-fetch-site",
    )
    _SAFE_RETRY_METHODS = {"GET", "HEAD", "OPTIONS"}
def __init__(self, config: dict): def __init__(self, config: dict):
self.connect_host = config.get("google_ip", "216.239.38.120") self.connect_host = config.get("google_ip", "216.239.38.120")
@@ -120,6 +134,16 @@ class DomainFronter:
self.auth_key = config.get("auth_key", "") self.auth_key = config.get("auth_key", "")
self.verify_ssl = config.get("verify_ssl", True) self.verify_ssl = config.get("verify_ssl", True)
        self._relay_timeout = self._cfg_float(
            config, "relay_timeout", RELAY_TIMEOUT, minimum=1.0,
        )
        self._tls_connect_timeout = self._cfg_float(
            config, "tls_connect_timeout", TLS_CONNECT_TIMEOUT, minimum=1.0,
        )
        self._max_response_body_bytes = self._cfg_int(
            config, "max_response_body_bytes", MAX_RESPONSE_BODY_BYTES,
            minimum=1024,
        )
# Connection pool — TTL-based, pre-warmed, with concurrency control # Connection pool — TTL-based, pre-warmed, with concurrency control
self._pool: list[tuple[asyncio.StreamReader, asyncio.StreamWriter, float]] = [] self._pool: list[tuple[asyncio.StreamReader, asyncio.StreamWriter, float]] = []
@@ -146,6 +170,9 @@ class DomainFronter:
# Request coalescing — dedup concurrent identical GETs # Request coalescing — dedup concurrent identical GETs
self._coalesce: dict[str, list[asyncio.Future]] = {} self._coalesce: dict[str, list[asyncio.Future]] = {}
        self._h2_failure_streak = 0
        self._h2_disabled_until = 0.0
        self._stream_download_disabled_until: dict[str, float] = {}
# HTTP/2 multiplexing — one connection handles all requests # HTTP/2 multiplexing — one connection handles all requests
self._h2 = None self._h2 = None
@@ -173,6 +200,23 @@ class DomainFronter:
# ── helpers ─────────────────────────────────────────────────── # ── helpers ───────────────────────────────────────────────────
    @staticmethod
    def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int:
        try:
            value = int(config.get(key, default))
        except (TypeError, ValueError):
            value = default
        return max(minimum, value)

    @staticmethod
    def _cfg_float(config: dict, key: str, default: float,
                   *, minimum: float = 0.1) -> float:
        try:
            value = float(config.get(key, default))
        except (TypeError, ValueError):
            value = default
        return max(minimum, value)

def _ssl_ctx(self) -> ssl.SSLContext: def _ssl_ctx(self) -> ssl.SSLContext:
ctx = ssl.create_default_context() ctx = ssl.create_default_context()
if not self.verify_ssl: if not self.verify_ssl:
@@ -180,6 +224,54 @@ class DomainFronter:
ctx.verify_mode = ssl.CERT_NONE ctx.verify_mode = ssl.CERT_NONE
return ctx return ctx
    def _h2_available(self) -> bool:
        return (
            self._h2 is not None
            and self._h2.is_connected
            and time.time() >= self._h2_disabled_until
        )

    def _record_h2_success(self) -> None:
        self._h2_failure_streak = 0

    def _record_h2_failure(self, exc: Exception) -> None:
        self._h2_failure_streak += 1
        if self._h2_failure_streak >= self._H2_FAILURE_THRESHOLD:
            self._h2_disabled_until = time.time() + self._H2_FAILURE_COOLDOWN
            log.warning(
                "H2 temporarily disabled for %.0fs after %d consecutive failures (%s)",
                self._H2_FAILURE_COOLDOWN,
                self._h2_failure_streak,
                type(exc).__name__,
            )
            self._h2_failure_streak = 0

    def _stream_download_allowed(self, url: str) -> bool:
        host = self._host_key(url)
        if not host:
            return True
        until = self._stream_download_disabled_until.get(host, 0.0)
        if until > time.time():
            return False
        if until:
            self._stream_download_disabled_until.pop(host, None)
        return True

    def _mark_stream_download_failure(self, url: str, reason: str) -> None:
        host = self._host_key(url)
        if not host:
            return
        self._stream_download_disabled_until[host] = (
            time.time() + self._DOWNLOAD_STREAM_COOLDOWN
        )
        log.warning(
            "Parallel streaming disabled for host %s for %.0fs after failure (%s)",
            host, self._DOWNLOAD_STREAM_COOLDOWN, reason,
        )

    def stream_download_allowed(self, url: str) -> bool:
        return self._stream_download_allowed(url)

async def _open(self): async def _open(self):
"""Open a TLS connection to the CDN. """Open a TLS connection to the CDN.
@@ -228,7 +320,7 @@ class DomainFronter:
except Exception: except Exception:
pass pass
reader, writer = await asyncio.wait_for( reader, writer = await asyncio.wait_for(
self._open(), timeout=TLS_CONNECT_TIMEOUT self._open(), timeout=self._tls_connect_timeout
) )
# Pool was empty — trigger aggressive background refill # Pool was empty — trigger aggressive background refill
if not self._refilling: if not self._refilling:
@@ -327,6 +419,171 @@ class DomainFronter:
host = parsed.hostname or url_or_host host = parsed.hostname or url_or_host
return host.lower().rstrip(".") return host.lower().rstrip(".")
    @classmethod
    def _coalesce_key(cls, url: str, headers: dict | None) -> str:
        key = [url]
        if headers:
            lowered = {str(k).lower(): str(v) for k, v in headers.items()}
            for name in cls._COALESCE_VARY_HEADERS:
                value = lowered.get(name)
                if value:
                    key.append(f"{name}={value}")
        return "\n".join(key)

    @classmethod
    def _retry_attempts_for_payload(cls, payload: dict) -> int:
        method = str(payload.get("m", "GET")).upper()
        return 2 if method in cls._SAFE_RETRY_METHODS else 1

    @staticmethod
    def _render_streaming_headers(resp_headers: dict, total_size: int) -> bytes:
        lines = ["HTTP/1.1 200 OK"]
        skip = {
            "transfer-encoding",
            "connection",
            "keep-alive",
            "content-length",
            "content-range",
        }
        for key, value in resp_headers.items():
            if key.lower() in skip:
                continue
            lines.append(f"{key}: {value}")
        lines.append(f"Content-Length: {total_size}")
        lines.append("")
        lines.append("")
        return "\r\n".join(lines).encode()

    @staticmethod
    def _parse_content_range(value: str) -> tuple[int, int, int] | None:
        match = re.match(r"^\s*bytes\s+(\d+)-(\d+)/(\d+)\s*$", value or "")
        if not match:
            return None
        start, end, total = (int(group) for group in match.groups())
        if start < 0 or end < start or total <= end:
            return None
        return start, end, total

    @classmethod
    def _validate_range_response(cls, status: int, resp_headers: dict,
                                 body: bytes, start_off: int,
                                 end_off: int,
                                 total_size: int | None = None) -> str | None:
        if status != 206:
            return f"status {status}"
        parsed = cls._parse_content_range(resp_headers.get("content-range", ""))
        if not parsed:
            return "missing/invalid Content-Range"
        got_start, got_end, got_total = parsed
        if got_start != start_off or got_end != end_off:
            return f"Content-Range mismatch {got_start}-{got_end}"
        if total_size is not None and got_total != total_size:
            return f"Content-Range total mismatch {got_total}/{total_size}"
        expected = end_off - start_off + 1
        if len(body) != expected:
            return f"short chunk {len(body)}/{expected} B"
        return None

    @staticmethod
    def _spool_write(file_obj, offset: int, data: bytes) -> None:
        file_obj.seek(offset)
        file_obj.write(data)
        file_obj.flush()

    @staticmethod
    def _spool_read(file_obj, offset: int, size: int) -> bytes:
        file_obj.seek(offset)
        return file_obj.read(size)

    @staticmethod
    def _format_bytes_human(num_bytes: int) -> str:
        value = float(max(0, num_bytes))
        units = ("B", "KiB", "MiB", "GiB", "TiB")
        unit = units[0]
        for unit in units:
            if value < 1024.0 or unit == units[-1]:
                break
            value /= 1024.0
        if unit == "B":
            return f"{int(value)} {unit}"
        return f"{value:.1f} {unit}"

    @staticmethod
    def _format_elapsed_short(seconds: float) -> str:
        total = max(0, int(seconds))
        minutes, secs = divmod(total, 60)
        hours, minutes = divmod(minutes, 60)
        if hours:
            return f"{hours:02d}:{minutes:02d}:{secs:02d}"
        return f"{minutes:02d}:{secs:02d}"

    @staticmethod
    def _render_progress_bar(done: int, total: int, width: int = 34) -> str:
        if total <= 0:
            return "[" + ("-" * width) + "]"
        ratio = max(0.0, min(1.0, done / total))
        filled = min(width, int(round(ratio * width)))
        return "[" + ("#" * filled) + ("-" * (width - filled)) + "]"

    @classmethod
    def _progress_line(cls, *, elapsed: float, done: int, total: int,
                       speed_bytes_per_sec: float) -> str:
        return (
            f"[{cls._format_elapsed_short(elapsed)}] "
            f"{cls._render_progress_bar(done, total)} "
            f"{cls._format_bytes_human(done)} / {cls._format_bytes_human(total)} "
            f"({cls._format_bytes_human(int(speed_bytes_per_sec))}/s)"
        )

    async def _relay_payload_h1(self, payload: dict) -> bytes:
        attempts = self._retry_attempts_for_payload(payload)
        async with self._semaphore:
            for attempt in range(attempts):
                try:
                    return await asyncio.wait_for(
                        self._relay_single(payload), timeout=self._relay_timeout,
                    )
                except Exception as exc:
                    if attempt < attempts - 1:
                        log.debug(
                            "H1 relay attempt %d failed (%s: %s), retrying",
                            attempt + 1, type(exc).__name__, exc,
                        )
                        await self._flush_pool()
                    else:
                        raise

    async def _range_probe(self, url: str, headers: dict, start_off: int,
                           end_off: int, *, max_tries: int = 3) -> bytes:
        probe_headers = dict(headers) if headers else {}
        probe_headers["Range"] = f"bytes={start_off}-{end_off}"
        probe_payload = self._build_payload("GET", url, probe_headers, b"")
        last_raw = b""
        last_status = 0
        for attempt in range(max_tries):
            try:
                last_raw = await self._relay_payload_h1(probe_payload)
            except Exception as exc:
                if attempt == max_tries - 1:
                    raise
                log.warning(
                    "Initial range probe %d-%d retry %d/%d failed: %r",
                    start_off, end_off, attempt + 1, max_tries, exc,
                )
                await asyncio.sleep(0.3 * (attempt + 1))
                continue
            last_status, _, _ = self._split_raw_response(last_raw)
            if last_status == 206 or last_status < 500:
                return last_raw
            if attempt < max_tries - 1:
                log.warning(
                    "Initial range probe %d-%d retry %d/%d: status %d",
                    start_off, end_off, attempt + 1, max_tries, last_status,
                )
                await asyncio.sleep(0.3 * (attempt + 1))
        return last_raw

# ── Per-host stats ──────────────────────────────────────────── # ── Per-host stats ────────────────────────────────────────────
def _record_site(self, url: str, bytes_: int, latency_ns: int, def _record_site(self, url: str, bytes_: int, latency_ns: int,
@@ -521,12 +778,17 @@ class DomainFronter:
async def close(self): async def close(self):
"""Cancel background tasks and close all pooled / H2 connections.""" """Cancel background tasks and close all pooled / H2 connections."""
for task in list(self._bg_tasks): tasks = list(self._bg_tasks)
for task in tasks:
task.cancel() task.cancel()
if self._bg_tasks: if tasks:
self._spawn(self._prewarm_script()) await asyncio.gather(*tasks, return_exceptions=True)
if self._keepalive_task is None or self._keepalive_task.done(): self._bg_tasks.clear()
self._keepalive_task = self._spawn
self._warm_task = None
self._maintenance_task = None
self._stats_task = None
self._keepalive_task = None
await self._flush_pool() await self._flush_pool()
@@ -538,18 +800,25 @@ class DomainFronter:
async def _h2_connect(self): async def _h2_connect(self):
"""Connect the HTTP/2 transport in background.""" """Connect the HTTP/2 transport in background."""
if self._h2 is None:
return
if time.time() < self._h2_disabled_until:
return
try: try:
await self._h2.ensure_connected() await self._h2.ensure_connected()
self._record_h2_success()
log.info("H2 multiplexing active — one conn handles all requests") log.info("H2 multiplexing active — one conn handles all requests")
except Exception as e: except Exception as e:
self._record_h2_failure(e)
log.warning("H2 connect failed (%s), using H1 pool fallback", e) log.warning("H2 connect failed (%s), using H1 pool fallback", e)
async def _h2_connect_and_warm(self): async def _h2_connect_and_warm(self):
"""Connect H2, pre-warm the Apps Script container, start keepalive.""" """Connect H2, pre-warm the Apps Script container, start keepalive."""
await self._h2_connect() await self._h2_connect()
if self._h2 and self._h2.is_connected: if self._h2_available():
asyncio.create_task(self._prewarm_script()) self._spawn(self._prewarm_script())
asyncio.create_task(self._keepalive_loop()) if self._keepalive_task is None or self._keepalive_task.done():
self._keepalive_task = self._spawn(self._keepalive_loop())
async def _prewarm_script(self): async def _prewarm_script(self):
"""Pre-warm Apps Script and detect /dev fast path (no redirect).""" """Pre-warm Apps Script and detect /dev fast path (no redirect)."""
@@ -602,10 +871,12 @@ class DomainFronter:
try: try:
await asyncio.sleep(240) # 4 minutes — saves ~90 quota hits/day vs 180s await asyncio.sleep(240) # 4 minutes — saves ~90 quota hits/day vs 180s
# Google's container timeout is ~5 min idle # Google's container timeout is ~5 min idle
if not self._h2 or not self._h2.is_connected: if not self._h2_available():
try: try:
await self._h2.reconnect() await self._h2.reconnect()
except Exception: self._record_h2_success()
except Exception as exc:
self._record_h2_failure(exc)
continue continue
# H2 PING to keep connection alive # H2 PING to keep connection alive
@@ -681,7 +952,9 @@ class DomainFronter:
has_range = True has_range = True
break break
if method == "GET" and not body and not has_range: if method == "GET" and not body and not has_range:
result = await self._coalesced_submit(url, payload) result = await self._coalesced_submit(
self._coalesce_key(url, headers), payload,
)
return result return result
result = await self._batch_submit(payload) result = await self._batch_submit(payload)
@@ -693,7 +966,7 @@ class DomainFronter:
latency_ns = int((time.perf_counter() - t0) * 1e9) latency_ns = int((time.perf_counter() - t0) * 1e9)
self._record_site(url, len(result), latency_ns, errored) self._record_site(url, len(result), latency_ns, errored)
async def _coalesced_submit(self, url: str, payload: dict) -> bytes: async def _coalesced_submit(self, key: str, payload: dict) -> bytes:
"""Dedup concurrent requests for the same URL (no Range header). """Dedup concurrent requests for the same URL (no Range header).
Uses `_batch_lock` to atomically check-and-append, preventing a Uses `_batch_lock` to atomically check-and-append, preventing a
@@ -702,14 +975,14 @@ class DomainFronter:
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
async with self._batch_lock: async with self._batch_lock:
waiters = self._coalesce.get(url) waiters = self._coalesce.get(key)
if waiters is not None: if waiters is not None:
future = loop.create_future() future = loop.create_future()
waiters.append(future) waiters.append(future)
log.debug("Coalesced request: %s", url[:60]) log.debug("Coalesced request: %s", key.split("\n", 1)[0][:60])
waiting = True waiting = True
else: else:
self._coalesce[url] = [] self._coalesce[key] = []
waiting = False waiting = False
if waiting: if waiting:
@@ -719,14 +992,14 @@ class DomainFronter:
result = await self._batch_submit(payload) result = await self._batch_submit(payload)
except Exception as e: except Exception as e:
async with self._batch_lock: async with self._batch_lock:
waiters = self._coalesce.pop(url, []) waiters = self._coalesce.pop(key, [])
for f in waiters: for f in waiters:
if not f.done(): if not f.done():
f.set_exception(e) f.set_exception(e)
raise raise
async with self._batch_lock: async with self._batch_lock:
waiters = self._coalesce.pop(url, []) waiters = self._coalesce.pop(key, [])
for f in waiters: for f in waiters:
if not f.done(): if not f.done():
f.set_result(result) f.set_result(result)
@@ -734,8 +1007,10 @@ class DomainFronter:
async def relay_parallel(self, method: str, url: str, async def relay_parallel(self, method: str, url: str,
headers: dict, body: bytes = b"", headers: dict, body: bytes = b"",
chunk_size: int = 256 * 1024, chunk_size: int = 512 * 1024,
max_parallel: int = 16) -> bytes: max_parallel: int = 8,
max_chunks: int = 256,
min_size: int = 0) -> bytes:
"""Relay with parallel range acceleration for large downloads. """Relay with parallel range acceleration for large downloads.
Strategy: Strategy:
@@ -747,17 +1022,15 @@ class DomainFronter:
Since each Apps Script call takes ~2s regardless of payload size, Since each Apps Script call takes ~2s regardless of payload size,
we use: we use:
- 256 KB chunks (safe under Apps Script response limit) - 512 KB chunks (fewer relay calls, lower quota pressure)
- Up to 16 chunks in flight at once via H2 multiplexing - Up to 8 chunks in flight at once via the H1 pool
- Aggregate throughput of ~2 MB per round-trip (~2-3s) - Aggregate throughput of ~2 MB per round-trip (~2-3s)
""" """
if method != "GET" or body: if method != "GET" or body:
return await self.relay(method, url, headers, body) return await self.relay(method, url, headers, body)
# Probe: first chunk with Range header # Probe: first chunk with Range header
range_headers = dict(headers) if headers else {} first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
range_headers["Range"] = f"bytes=0-{chunk_size - 1}"
first_resp = await self.relay("GET", url, range_headers, b"")
status, resp_hdrs, resp_body = self._split_raw_response(first_resp) status, resp_hdrs, resp_body = self._split_raw_response(first_resp)
@@ -768,13 +1041,40 @@ class DomainFronter:
return first_resp return first_resp
# Parse total size from Content-Range: "bytes 0-262143/1048576" # Parse total size from Content-Range: "bytes 0-262143/1048576"
content_range = resp_hdrs.get("content-range", "") parsed_range = self._parse_content_range(resp_hdrs.get("content-range", ""))
m = re.search(r"/(\d+)", content_range) if not parsed_range:
if not m:
# Can't parse — downgrade to 200 so the client (which sent a # Can't parse — downgrade to 200 so the client (which sent a
# plain GET) doesn't get confused by 206 + Content-Range. # plain GET) doesn't get confused by 206 + Content-Range.
return self._rewrite_206_to_200(first_resp) return self._rewrite_206_to_200(first_resp)
total_size = int(m.group(1)) first_start, first_end, total_size = parsed_range
first_err = self._validate_range_response(
status, resp_hdrs, resp_body, first_start, first_end, total_size,
)
if first_start != 0 or first_err:
return self._rewrite_206_to_200(first_resp)
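Review note: `_parse_content_range` and `_validate_range_response` are not visible in this hunk, so the following is only a guess at their contract, inferred from the call sites (a `(start, end, total)` tuple or a falsy value from the parser, `None` or an error string from the validator). The function names without the leading underscore and the exact error messages are assumptions.

```python
import re

# Matches e.g. "bytes 0-262143/1048576"; an unknown total ("/*") is rejected.
_CONTENT_RANGE = re.compile(r"^bytes\s+(\d+)-(\d+)/(\d+)$", re.IGNORECASE)


def parse_content_range(value: str):
    """Return (start, end, total) or None when the header is unusable."""
    m = _CONTENT_RANGE.match(value.strip())
    if not m:
        return None
    start, end, total = (int(g) for g in m.groups())
    if start > end or end >= total:
        return None
    return start, end, total


def validate_range_response(status, headers, body, start, end, total):
    """Return None when the chunk is acceptable, else a short error string."""
    if status != 206:
        return f"unexpected status {status}"
    parsed = parse_content_range(headers.get("content-range", ""))
    if parsed is None:
        return "missing/invalid Content-Range"
    if parsed != (start, end, total):
        return f"range mismatch {parsed}"
    expected = end - start + 1
    if len(body) != expected:
        return f"short chunk {len(body)}/{expected} B"
    return None
```

Checking the echoed range as well as the body length catches an upstream that ignores `Range` and returns the right number of bytes from the wrong offset, which a length-only check would miss.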
if total_size > self._max_response_body_bytes:
return self._error_response(
502,
"Relay response exceeds cap "
f"({self._max_response_body_bytes} bytes). "
"Increase max_response_body_bytes if your system has enough RAM.",
)
if min_size > 0 and total_size < min_size:
return self._rewrite_206_to_200(first_resp)
if max_chunks > 0:
required_chunk_size = max(
chunk_size,
(total_size + max_chunks - 1) // max_chunks,
)
if required_chunk_size != chunk_size:
log.info(
"Parallel download tuning: chunk size raised from %d KB to %d KB "
"to keep request count under %d",
chunk_size // 1024,
required_chunk_size // 1024,
max_chunks,
)
chunk_size = required_chunk_size
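Review note: the tuning rule above is a ceiling division, so the chunk size only grows when the file would otherwise need more than `max_chunks` requests. A small worked example, with the hypothetical helper name `tune_chunk_size` standing in for the inline logic:

```python
KB = 1024
MB = 1024 * KB


def tune_chunk_size(total_size: int, chunk_size: int, max_chunks: int) -> int:
    """Raise chunk_size just enough that total_size fits in max_chunks requests."""
    if max_chunks <= 0:
        return chunk_size  # cap disabled, mirror the `if max_chunks > 0` guard
    # Ceiling division: smallest chunk size that covers the file in max_chunks parts.
    return max(chunk_size, (total_size + max_chunks - 1) // max_chunks)


def chunk_count(total_size: int, chunk_size: int) -> int:
    return (total_size + chunk_size - 1) // chunk_size


small = tune_chunk_size(100 * MB, 512 * KB, 256)   # 200 chunks needed, unchanged
large = tune_chunk_size(1024 * MB, 512 * KB, 256)  # 2048 chunks needed, raised
```

With the defaults in this commit (512 KB chunks, 256 chunks maximum), files up to 128 MB keep the configured chunk size; beyond that each request grows, so the Apps Script response size limit mentioned in the old docstring is worth keeping in mind for very large files.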
# Small file: probe already fetched it all. MUST rewrite to 200 # Small file: probe already fetched it all. MUST rewrite to 200
# because the client never sent a Range header — a stray 206 here # because the client never sent a Range header — a stray 206 here
@@ -795,22 +1095,54 @@ class DomainFronter:
# Concurrency-limited parallel fetch # Concurrency-limited parallel fetch
sem = asyncio.Semaphore(max_parallel) sem = asyncio.Semaphore(max_parallel)
progress_lock = asyncio.Lock()
completed_chunks = 1 # first range probe already succeeded
completed_bytes = len(resp_body)
last_progress_log = time.perf_counter()
total_chunks = len(ranges) + 1
total_bytes = total_size
async def fetch_range(s, e, max_tries: int = 3): async def fetch_range(s, e, max_tries: int = 3):
nonlocal completed_chunks, completed_bytes, last_progress_log
async with sem: async with sem:
rh_base = dict(headers) if headers else {} rh_base = dict(headers) if headers else {}
rh_base["Range"] = f"bytes={s}-{e}" rh_base["Range"] = f"bytes={s}-{e}"
payload = self._build_payload("GET", url, rh_base, b"")
expected = e - s + 1 expected = e - s + 1
last_err = None last_err = None
for attempt in range(max_tries): for attempt in range(max_tries):
try: try:
raw = await self.relay("GET", url, rh_base, b"") raw = await self._relay_payload_h1(payload)
_, _, chunk_body = self._split_raw_response(raw) chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw)
if len(chunk_body) == expected: err = self._validate_range_response(
return chunk_body chunk_status, chunk_headers, chunk_body,
last_err = ( s, e, total_size,
f"short chunk {len(chunk_body)}/{expected} B"
) )
if err is None:
now = time.perf_counter()
async with progress_lock:
completed_chunks += 1
completed_bytes += len(chunk_body)
should_log = (
completed_chunks == total_chunks
or (now - last_progress_log) >= 5.0
)
if should_log:
elapsed = max(0.001, now - t0)
speed_bps = completed_bytes / elapsed
log.info(
"Parallel download progress: %s [%d/%d chunks]",
self._progress_line(
elapsed=elapsed,
done=completed_bytes,
total=total_bytes,
speed_bytes_per_sec=speed_bps,
),
completed_chunks, total_chunks,
)
last_progress_log = now
return chunk_body
last_err = err
except Exception as e_: except Exception as e_:
last_err = repr(e_) last_err = repr(e_)
log.warning("Range %d-%d retry %d/%d: %s", log.warning("Range %d-%d retry %d/%d: %s",
@@ -837,8 +1169,15 @@ class DomainFronter:
full_body = b"".join(parts) full_body = b"".join(parts)
kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0 kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0
log.info("Parallel download complete: %d B in %.2fs = %.1f KB/s", log.info(
len(full_body), elapsed, kbs) "Parallel download complete: %s",
self._progress_line(
elapsed=elapsed,
done=len(full_body),
total=len(full_body),
speed_bytes_per_sec=kbs * 1024,
),
)
# Return as 200 OK (client sent a normal GET) # Return as 200 OK (client sent a normal GET)
result = f"HTTP/1.1 200 OK\r\n" result = f"HTTP/1.1 200 OK\r\n"
@@ -851,6 +1190,219 @@ class DomainFronter:
result += "\r\n" result += "\r\n"
return result.encode() + full_body return result.encode() + full_body
async def stream_parallel_download(self, url: str, headers: dict,
writer,
*,
chunk_size: int = 512 * 1024,
max_parallel: int = 8,
max_chunks: int = 256,
min_size: int = 0) -> bool:
"""Stream a large range-capable download to the client incrementally.
Returns False when the target should fall back to the normal relay
path (for example no range support or the file is too small).
Returns True once this method has taken ownership of the client
response, even if the stream later aborts.
"""
first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
status, resp_hdrs, resp_body = self._split_raw_response(first_resp)
if status != 206:
log.info(
"Streaming download fallback: initial probe returned %s for %s",
status, url[:80],
)
return False
parsed_range = self._parse_content_range(resp_hdrs.get("content-range", ""))
if not parsed_range:
log.info(
"Streaming download fallback: missing/invalid Content-Range for %s",
url[:80],
)
return False
first_start, first_end, total_size = parsed_range
first_err = self._validate_range_response(
status, resp_hdrs, resp_body, first_start, first_end, total_size,
)
if first_start != 0 or first_err:
log.info(
"Streaming download fallback: invalid first range (%s) for %s",
first_err or f"start={first_start}",
url[:80],
)
return False
if min_size > 0 and total_size < min_size:
log.info(
"Streaming download fallback: file too small (%d < %d) for %s",
total_size, min_size, url[:80],
)
return False
if max_chunks > 0:
required_chunk_size = max(
chunk_size,
(total_size + max_chunks - 1) // max_chunks,
)
if required_chunk_size != chunk_size:
log.info(
"Parallel download tuning: chunk size raised from %d KB to %d KB "
"to keep request count under %d",
chunk_size // 1024,
required_chunk_size // 1024,
max_chunks,
)
chunk_size = required_chunk_size
if total_size <= chunk_size or len(resp_body) >= total_size:
writer.write(self._render_streaming_headers(resp_hdrs, total_size))
writer.write(resp_body)
await writer.drain()
return True
ranges = []
start = len(resp_body)
while start < total_size:
end = min(start + chunk_size - 1, total_size - 1)
ranges.append((start, end))
start = end + 1
log.info("Parallel streaming download: %d bytes, %d chunks of %d KB",
total_size, len(ranges) + 1, chunk_size // 1024)
temp_file = tempfile.TemporaryFile(prefix="mhrvpn_dl_")
file_lock = asyncio.Lock()
sem = asyncio.Semaphore(max_parallel)
cancel_event = asyncio.Event()
tasks: list[asyncio.Task] = []
ready = [asyncio.Event() for _ in ranges]
errors: list[Exception | None] = [None for _ in ranges]
delivered_chunks = 1
delivered_bytes = len(resp_body)
total_chunks = len(ranges) + 1
last_progress_log = time.perf_counter()
t0 = time.perf_counter()
async def _write_progress(force: bool = False) -> None:
nonlocal last_progress_log
now = time.perf_counter()
if not force and (now - last_progress_log) < 5.0:
return
elapsed = max(0.001, now - t0)
speed_bps = delivered_bytes / elapsed
log.info(
"Parallel download progress: %s [%d/%d chunks]",
self._progress_line(
elapsed=elapsed,
done=delivered_bytes,
total=total_size,
speed_bytes_per_sec=speed_bps,
),
delivered_chunks, total_chunks,
)
last_progress_log = now
async def fetch_range(index: int, start_off: int, end_off: int,
max_tries: int = 3) -> None:
async with sem:
base_headers = dict(headers) if headers else {}
base_headers["Range"] = f"bytes={start_off}-{end_off}"
payload = self._build_payload("GET", url, base_headers, b"")
expected = end_off - start_off + 1
last_err = "unknown"
try:
for attempt in range(max_tries):
if cancel_event.is_set():
return
try:
raw = await self._relay_payload_h1(payload)
chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw)
err = self._validate_range_response(
chunk_status, chunk_headers, chunk_body,
start_off, end_off, total_size,
)
if err is None:
async with file_lock:
await asyncio.to_thread(
self._spool_write, temp_file, start_off, chunk_body,
)
ready[index].set()
return
last_err = err
except Exception as exc:
last_err = repr(exc)
if cancel_event.is_set():
return
log.warning("Range %d-%d retry %d/%d: %s",
start_off, end_off, attempt + 1, max_tries, last_err)
await asyncio.sleep(0.3 * (attempt + 1))
errors[index] = RuntimeError(
f"chunk {start_off}-{end_off} failed after {max_tries} tries: {last_err}"
)
ready[index].set()
except asyncio.CancelledError:
raise
try:
writer.write(self._render_streaming_headers(resp_hdrs, total_size))
writer.write(resp_body)
await writer.drain()
for index, (start_off, end_off) in enumerate(ranges):
tasks.append(asyncio.create_task(fetch_range(index, start_off, end_off)))
for index, (start_off, end_off) in enumerate(ranges):
await ready[index].wait()
if errors[index] is not None:
raise errors[index]
expected = end_off - start_off + 1
async with file_lock:
chunk = await asyncio.to_thread(
self._spool_read, temp_file, start_off, expected,
)
if len(chunk) != expected:
raise RuntimeError(
f"spooled chunk {start_off}-{end_off} was truncated "
f"({len(chunk)}/{expected} B)"
)
writer.write(chunk)
await writer.drain()
delivered_chunks += 1
delivered_bytes += len(chunk)
await _write_progress(force=(index == len(ranges) - 1))
elapsed = max(0.001, time.perf_counter() - t0)
log.info(
"Parallel streaming download complete: %s",
self._progress_line(
elapsed=elapsed,
done=total_size,
total=total_size,
speed_bytes_per_sec=total_size / elapsed,
),
)
return True
except (ConnectionError, BrokenPipeError, TimeoutError) as exc:
log.info("Parallel download cancelled by client: %s", exc)
cancel_event.set()
return True
except Exception as exc:
self._mark_stream_download_failure(url, str(exc))
log.error("Parallel streaming download failed (%s): %s", url[:60], exc)
cancel_event.set()
try:
if not writer.is_closing():
writer.close()
except Exception:
pass
return True
finally:
cancel_event.set()
for task in tasks:
task.cancel()
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
temp_file.close()
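Review note: the core of `stream_parallel_download` is "fetch out of order, deliver in order", with one `asyncio.Event` per range. The sketch below isolates that pattern under simplifying assumptions: an in-memory dict replaces the temp-file spool, `asyncio.sleep` with random jitter replaces the relay call, and the function name `ordered_delivery` is hypothetical.

```python
import asyncio
import random


async def ordered_delivery(chunks: list[bytes], max_parallel: int = 4) -> bytes:
    sem = asyncio.Semaphore(max_parallel)
    ready = [asyncio.Event() for _ in chunks]
    store: dict[int, bytes] = {}  # stands in for the disk-backed spool
    out = bytearray()             # stands in for the client writer

    async def fetch(index: int) -> None:
        async with sem:
            # Simulated network jitter: chunks complete in arbitrary order.
            await asyncio.sleep(random.uniform(0, 0.02))
            store[index] = chunks[index]
            ready[index].set()

    tasks = [asyncio.create_task(fetch(i)) for i in range(len(chunks))]
    try:
        # The consumer walks the ranges in order and blocks only on the
        # next chunk it needs, so later chunks can finish early and wait.
        for index in range(len(chunks)):
            await ready[index].wait()
            out += store.pop(index)
    finally:
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
    return bytes(out)
```

One consequence of this design is that the spool can temporarily hold up to `max_parallel` completed chunks ahead of the delivery cursor, which is the reason a disk-backed spool is less memory-hungry than buffering in RAM.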
@staticmethod @staticmethod
def _rewrite_206_to_200(raw: bytes) -> bytes: def _rewrite_206_to_200(raw: bytes) -> bytes:
"""Rewrite a 206 Partial Content response to 200 OK. """Rewrite a 206 Partial Content response to 200 OK.
@@ -1039,33 +1591,43 @@ class DomainFronter:
async def _relay_with_retry(self, payload: dict) -> bytes: async def _relay_with_retry(self, payload: dict) -> bytes:
"""Single relay with one retry on failure. Uses H2 if available.""" """Single relay with one retry on failure. Uses H2 if available."""
attempts = self._retry_attempts_for_payload(payload)
# Fan-out: race N Apps Script instances when enabled and H2 is up. # Fan-out: race N Apps Script instances when enabled and H2 is up.
# Cuts tail latency when one container is slow/cold. Only kicks in # Cuts tail latency when one container is slow/cold. Only kicks in
# if multiple script IDs are configured and the H2 transport is live. # if multiple script IDs are configured and the H2 transport is live.
if (self._parallel_relay > 1 if (attempts > 1
and self._parallel_relay > 1
and len(self._script_ids) > 1 and len(self._script_ids) > 1
and self._h2 and self._h2.is_connected): and self._h2_available()):
try: try:
return await asyncio.wait_for( result = await asyncio.wait_for(
self._relay_fanout(payload), timeout=RELAY_TIMEOUT, self._relay_fanout(payload), timeout=self._relay_timeout,
) )
self._record_h2_success()
return result
except Exception as e: except Exception as e:
self._record_h2_failure(e)
log.debug("Fan-out relay failed (%s), falling back", e) log.debug("Fan-out relay failed (%s), falling back", e)
# fall through to single-path logic below # fall through to single-path logic below
# Try HTTP/2 first — much faster (multiplexed, no pool checkout) # Try HTTP/2 first — much faster (multiplexed, no pool checkout)
if self._h2 and self._h2.is_connected: if self._h2_available():
for attempt in range(2): for attempt in range(attempts):
try: try:
return await asyncio.wait_for( result = await asyncio.wait_for(
self._relay_single_h2(payload), timeout=RELAY_TIMEOUT self._relay_single_h2(payload), timeout=self._relay_timeout
) )
self._record_h2_success()
return result
except Exception as e: except Exception as e:
if attempt == 0: self._record_h2_failure(e)
if attempt < attempts - 1:
log.debug("H2 relay failed (%s), reconnecting", e) log.debug("H2 relay failed (%s), reconnecting", e)
try: try:
await self._h2.reconnect() await self._h2.reconnect()
except Exception: self._record_h2_success()
except Exception as reconnect_exc:
self._record_h2_failure(reconnect_exc)
log.warning("H2 reconnect failed, falling back to H1") log.warning("H2 reconnect failed, falling back to H1")
break break
else: else:
@@ -1073,14 +1635,15 @@ class DomainFronter:
# HTTP/1.1 fallback (pool-based) # HTTP/1.1 fallback (pool-based)
async with self._semaphore: async with self._semaphore:
for attempt in range(2): for attempt in range(attempts):
try: try:
return await asyncio.wait_for( return await asyncio.wait_for(
self._relay_single(payload), timeout=RELAY_TIMEOUT self._relay_single(payload), timeout=self._relay_timeout
) )
except Exception as e: except Exception as e:
if attempt == 0: if attempt < attempts - 1:
log.debug("Relay attempt 1 failed (%s: %s), retrying", log.debug("Relay attempt %d failed (%s: %s), retrying",
attempt + 1,
type(e).__name__, e) type(e).__name__, e)
await self._flush_pool() await self._flush_pool()
else: else:
@@ -1248,7 +1811,7 @@ class DomainFronter:
path = self._exec_path(payloads[0].get("u") if payloads else None) path = self._exec_path(payloads[0].get("u") if payloads else None)
# Try HTTP/2 first # Try HTTP/2 first
if self._h2 and self._h2.is_connected: if self._h2_available():
try: try:
status, headers, body = await asyncio.wait_for( status, headers, body = await asyncio.wait_for(
self._h2.request( self._h2.request(
@@ -1258,8 +1821,10 @@ class DomainFronter:
), ),
timeout=30, timeout=30,
) )
self._record_h2_success()
return self._parse_batch_body(body, payloads) return self._parse_batch_body(body, payloads)
except Exception as e: except Exception as e:
self._record_h2_failure(e)
log.debug("H2 batch failed (%s), falling back to H1", e) log.debug("H2 batch failed (%s), falling back to H1", e)
# HTTP/1.1 fallback # HTTP/1.1 fallback
@@ -1383,7 +1948,13 @@ class DomainFronter:
if "chunked" in transfer_encoding: if "chunked" in transfer_encoding:
body = await self._read_chunked(reader, body) body = await self._read_chunked(reader, body)
elif content_length: elif content_length:
remaining = int(content_length) - len(body) total = int(content_length)
if total > self._max_response_body_bytes:
raise RuntimeError(
"Relay response exceeds configured size cap "
f"({total} > {self._max_response_body_bytes} bytes)"
)
remaining = total - len(body)
while remaining > 0: while remaining > 0:
chunk = await asyncio.wait_for( chunk = await asyncio.wait_for(
reader.read(min(remaining, 65536)), timeout=20 reader.read(min(remaining, 65536)), timeout=20
@@ -1391,6 +1962,10 @@ class DomainFronter:
if not chunk: if not chunk:
break break
body += chunk body += chunk
if len(body) > self._max_response_body_bytes:
raise RuntimeError(
"Relay response exceeded configured size cap while reading body"
)
remaining -= len(chunk) remaining -= len(chunk)
else: else:
# No framing — short timeout read (keep-alive safe) # No framing — short timeout read (keep-alive safe)
@@ -1400,6 +1975,10 @@ class DomainFronter:
if not chunk: if not chunk:
break break
body += chunk body += chunk
if len(body) > self._max_response_body_bytes:
raise RuntimeError(
"Relay response exceeded configured size cap while streaming"
)
except asyncio.TimeoutError: except asyncio.TimeoutError:
break break
@@ -1407,13 +1986,17 @@ class DomainFronter:
enc = headers.get("content-encoding", "") enc = headers.get("content-encoding", "")
if enc: if enc:
body = codec.decode(body, enc) body = codec.decode(body, enc)
if len(body) > self._max_response_body_bytes:
raise RuntimeError(
"Decoded relay response exceeded configured size cap"
)
return status, headers, body return status, headers, body
async def _read_chunked(self, reader, buf=b""): async def _read_chunked(self, reader, buf=b""):
"""Incrementally read chunked transfer-encoding.""" """Incrementally read chunked transfer-encoding."""
result = b"" result = b""
_MAX_BODY = 200 * 1024 * 1024 # 200 MB total body cap max_body = self._max_response_body_bytes
while True: while True:
while b"\r\n" not in buf: while b"\r\n" not in buf:
data = await asyncio.wait_for(reader.read(8192), timeout=20) data = await asyncio.wait_for(reader.read(8192), timeout=20)
@@ -1433,9 +2016,11 @@ class DomainFronter:
break break
if size == 0: if size == 0:
break break
if size > _MAX_BODY or len(result) + size > _MAX_BODY: if size > max_body or len(result) + size > max_body:
log.warning("Chunked body exceeds %d MB cap — truncating", _MAX_BODY // (1024 * 1024)) raise RuntimeError(
break "Chunked relay response exceeded configured size cap "
f"({max_body} bytes)"
)
while len(buf) < size + 2: while len(buf) < size + 2:
data = await asyncio.wait_for(reader.read(65536), timeout=20) data = await asyncio.wait_for(reader.read(65536), timeout=20)
@@ -1479,6 +2064,13 @@ class DomainFronter:
status = data.get("s", 200) status = data.get("s", 200)
resp_headers = data.get("h", {}) resp_headers = data.get("h", {})
resp_body = base64.b64decode(data.get("b", "")) resp_body = base64.b64decode(data.get("b", ""))
if len(resp_body) > self._max_response_body_bytes:
return self._error_response(
502,
"Relay response exceeds cap "
f"({self._max_response_body_bytes} bytes). "
"Increase max_response_body_bytes if your system has enough RAM.",
)
status_text = {200: "OK", 206: "Partial Content", status_text = {200: "OK", 206: "Partial Content",
301: "Moved", 302: "Found", 304: "Not Modified", 301: "Moved", 302: "Found", 304: "Not Modified",
+31 -16
@@ -80,6 +80,7 @@ class H2Transport:
self._write_lock = asyncio.Lock() self._write_lock = asyncio.Lock()
self._connect_lock = asyncio.Lock() self._connect_lock = asyncio.Lock()
self._read_task: asyncio.Task | None = None self._read_task: asyncio.Task | None = None
self._conn_generation = 0
# Per-stream tracking # Per-stream tracking
self._streams: dict[int, _StreamState] = {} self._streams: dict[int, _StreamState] = {}
@@ -174,26 +175,34 @@ class H2Transport:
await self._flush() await self._flush()
self._connected = True self._connected = True
self._read_task = asyncio.create_task(self._reader_loop()) self._conn_generation += 1
generation = self._conn_generation
self._read_task = asyncio.create_task(self._reader_loop(generation))
log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)", log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)",
self.connect_host, sni) self.connect_host, sni)
async def reconnect(self): async def reconnect(self):
"""Close current connection and re-establish.""" """Close current connection and re-establish."""
await self._close_internal() async with self._connect_lock:
await self._do_connect() await self._close_internal()
await self._do_connect()
async def _close_internal(self): async def _close_internal(self):
self._connected = False self._connected = False
if self._read_task: read_task = self._read_task
self._read_task.cancel() self._read_task = None
self._read_task = None if read_task:
read_task.cancel()
await asyncio.gather(read_task, return_exceptions=True)
if self._writer: if self._writer:
try: try:
self._writer.close() writer = self._writer
self._writer = None
writer.close()
await writer.wait_closed()
except Exception: except Exception:
pass pass
self._writer = None self._reader = None
# Wake all pending streams so they can raise # Wake all pending streams so they can raise
for state in self._streams.values(): for state in self._streams.values():
state.error = "Connection closed" state.error = "Connection closed"
@@ -327,7 +336,7 @@ class H2Transport:
# ── Background reader ───────────────────────────────────────── # ── Background reader ─────────────────────────────────────────
async def _reader_loop(self): async def _reader_loop(self, generation: int):
"""Background: read H2 frames, dispatch events to waiting streams.""" """Background: read H2 frames, dispatch events to waiting streams."""
try: try:
while self._connected: while self._connected:
@@ -352,14 +361,20 @@ class H2Transport:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
except Exception as e: except Exception as e:
log.error("H2 reader error: %s", e) if "application data after close notify" in str(e).lower():
log.debug("H2 reader closed after close_notify: %s", e)
else:
log.error("H2 reader error: %s", e)
finally: finally:
self._connected = False if generation != self._conn_generation:
for state in self._streams.values(): log.debug("H2 reader loop ended for stale generation %d", generation)
if not state.done.is_set(): else:
state.error = "Connection lost" self._connected = False
state.done.set() for state in self._streams.values():
log.info("H2 reader loop ended") if not state.done.is_set():
state.error = "Connection lost"
state.done.set()
log.info("H2 reader loop ended")
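Review note: the generation check in the `finally` block guards against an old reader task tearing down a connection it no longer owns. The toy class below is not the real `H2Transport`; `Conn`, its attributes, and the `demo` scenario are assumptions chosen to show the failure mode the counter prevents.

```python
import asyncio


class Conn:
    """Hypothetical miniature of the generation guard, not the real transport."""

    def __init__(self):
        self.connected = False
        self.generation = 0
        self._task: asyncio.Task | None = None

    async def connect(self) -> None:
        self.connected = True
        self.generation += 1
        # Each reader remembers which connection generation it belongs to.
        self._task = asyncio.create_task(self._reader(self.generation))

    async def _reader(self, generation: int) -> None:
        try:
            await asyncio.sleep(3600)  # placeholder for the frame-reading loop
        finally:
            # Only the reader of the current generation may mark the
            # transport as disconnected.
            if generation == self.generation:
                self.connected = False


async def demo() -> bool:
    conn = Conn()
    await conn.connect()
    await asyncio.sleep(0)   # let the first reader start
    old = conn._task
    old.cancel()             # cancellation lands on a later loop tick
    await conn.connect()     # generation is now 2
    await asyncio.gather(old, return_exceptions=True)
    alive = conn.connected   # stale reader must not have cleared this
    conn._task.cancel()
    await asyncio.gather(conn._task, return_exceptions=True)
    return alive
```

Without the comparison, the stale reader's `finally` would flip `connected` to `False` right after the reconnect succeeded. The commit also awaits the cancelled task inside `_close_internal`, so the counter acts as a second line of defence rather than the only one.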
def _dispatch(self, event): def _dispatch(self, event):
"""Route a single h2 event to its stream.""" """Route a single h2 event to its stream."""
+199 -114
@@ -65,6 +65,23 @@ def _parse_content_length(header_block: bytes) -> int:
return 0 return 0
def _has_unsupported_transfer_encoding(header_block: bytes) -> bool:
"""True when the request uses Transfer-Encoding, which we don't stream."""
for raw_line in header_block.split(b"\r\n"):
name, sep, value = raw_line.partition(b":")
if not sep:
continue
if name.strip().lower() != b"transfer-encoding":
continue
encodings = [
token.strip().lower()
for token in value.decode(errors="replace").split(",")
if token.strip()
]
return any(token != "identity" for token in encodings)
return False
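Review note: a few concrete inputs make the behaviour of this helper easier to check. The function body below is copied from the hunk above with the leading underscore dropped; the sample header blocks are made up for illustration.

```python
def has_unsupported_transfer_encoding(header_block: bytes) -> bool:
    """Copy of the helper above, reproduced so the examples run standalone."""
    for raw_line in header_block.split(b"\r\n"):
        name, sep, value = raw_line.partition(b":")
        if not sep:
            continue
        if name.strip().lower() != b"transfer-encoding":
            continue
        encodings = [
            token.strip().lower()
            for token in value.decode(errors="replace").split(",")
            if token.strip()
        ]
        return any(token != "identity" for token in encodings)
    return False


chunked = b"POST / HTTP/1.1\r\nHost: example.test\r\nTransfer-Encoding: chunked\r\n\r\n"
identity = b"POST / HTTP/1.1\r\nTransfer-Encoding: identity\r\n\r\n"
plain = b"POST / HTTP/1.1\r\nContent-Length: 4\r\n\r\n"
stacked = b"POST / HTTP/1.1\r\nTransfer-Encoding: gzip, chunked\r\n\r\n"
repeated = (
    b"POST / HTTP/1.1\r\n"
    b"Transfer-Encoding: identity\r\n"
    b"Transfer-Encoding: chunked\r\n\r\n"
)
```

Because the function returns from inside the loop, only the first `Transfer-Encoding` header is inspected: the `repeated` block above is reported as supported even though its second header says `chunked`. If that matters for the proxy, collecting tokens from every matching header before deciding may be worth considering.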
class ResponseCache: class ResponseCache:
"""Simple LRU response cache — avoids repeated relay calls.""" """Simple LRU response cache — avoids repeated relay calls."""
@@ -147,6 +164,14 @@ class ProxyServer:
_GOOGLE_DIRECT_ALLOW_EXACT = GOOGLE_DIRECT_ALLOW_EXACT _GOOGLE_DIRECT_ALLOW_EXACT = GOOGLE_DIRECT_ALLOW_EXACT
_GOOGLE_DIRECT_ALLOW_SUFFIXES = GOOGLE_DIRECT_ALLOW_SUFFIXES _GOOGLE_DIRECT_ALLOW_SUFFIXES = GOOGLE_DIRECT_ALLOW_SUFFIXES
_TRACE_HOST_SUFFIXES = TRACE_HOST_SUFFIXES _TRACE_HOST_SUFFIXES = TRACE_HOST_SUFFIXES
_DOWNLOAD_DEFAULT_EXTS = tuple(sorted(LARGE_FILE_EXTS))
_DOWNLOAD_ACCEPT_MARKERS = (
"application/octet-stream",
"application/zip",
"application/x-bittorrent",
"video/",
"audio/",
)
def __init__(self, config: dict): def __init__(self, config: dict):
self.host = config.get("listen_host", "127.0.0.1") self.host = config.get("listen_host", "127.0.0.1")
@@ -159,6 +184,30 @@ class ProxyServer:
self._cache = ResponseCache(max_mb=CACHE_MAX_MB) self._cache = ResponseCache(max_mb=CACHE_MAX_MB)
self._direct_fail_until: dict[str, float] = {} self._direct_fail_until: dict[str, float] = {}
self._servers: list[asyncio.base_events.Server] = [] self._servers: list[asyncio.base_events.Server] = []
self._client_tasks: set[asyncio.Task] = set()
self._tcp_connect_timeout = self._cfg_float(
config, "tcp_connect_timeout", TCP_CONNECT_TIMEOUT, minimum=1.0,
)
self._download_min_size = self._cfg_int(
config, "chunked_download_min_size", 5 * 1024 * 1024, minimum=0,
)
self._download_chunk_size = self._cfg_int(
config, "chunked_download_chunk_size", 512 * 1024, minimum=64 * 1024,
)
self._download_max_parallel = self._cfg_int(
config, "chunked_download_max_parallel", 8, minimum=1,
)
self._download_max_chunks = self._cfg_int(
config, "chunked_download_max_chunks", 256, minimum=1,
)
self._download_extensions, self._download_any_extension = (
self._normalize_download_extensions(
config.get(
"chunked_download_extensions",
list(self._DOWNLOAD_DEFAULT_EXTS),
)
)
)
# hosts override — DNS fake-map: domain/suffix → IP # hosts override — DNS fake-map: domain/suffix → IP
# Checked before any real DNS lookup; supports exact and suffix matching. # Checked before any real DNS lookup; supports exact and suffix matching.
@@ -198,6 +247,55 @@ class ProxyServer:
# ── Host-policy helpers ─────────────────────────────────────── # ── Host-policy helpers ───────────────────────────────────────
@staticmethod
def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int:
try:
value = int(config.get(key, default))
except (TypeError, ValueError):
value = default
return max(minimum, value)
@staticmethod
def _cfg_float(config: dict, key: str, default: float,
*, minimum: float = 0.1) -> float:
try:
value = float(config.get(key, default))
except (TypeError, ValueError):
value = default
return max(minimum, value)
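Review note: the clamping behaviour of `_cfg_int` shown with sample configs. The function is copied from the hunk above (leading underscore dropped); the config dictionaries are invented for the example, though the key name and the 512 KB default with a 64 KB floor match the constructor in this diff.

```python
def cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int:
    """Copy of the helper above, reproduced so the examples run standalone."""
    try:
        value = int(config.get(key, default))
    except (TypeError, ValueError):
        value = default
    return max(minimum, value)


KEY = "chunked_download_chunk_size"
DEFAULT = 512 * 1024
FLOOR = 64 * 1024

from_string = cfg_int({KEY: "1048576"}, KEY, DEFAULT, minimum=FLOOR)  # numeric string is accepted
from_garbage = cfg_int({KEY: "abc"}, KEY, DEFAULT, minimum=FLOOR)     # ValueError, falls back to default
from_none = cfg_int({KEY: None}, KEY, DEFAULT, minimum=FLOOR)         # TypeError, falls back to default
too_small = cfg_int({KEY: 1024}, KEY, DEFAULT, minimum=FLOOR)         # clamped up to the floor
missing = cfg_int({}, KEY, DEFAULT, minimum=FLOOR)                    # key absent, default used
```

Invalid values are replaced silently; a log line at the fallback point might help users notice a typo in their config, though that is a suggestion rather than something this commit does.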
@classmethod
def _normalize_download_extensions(cls, raw) -> tuple[tuple[str, ...], bool]:
values = raw if isinstance(raw, (list, tuple)) else cls._DOWNLOAD_DEFAULT_EXTS
normalized: list[str] = []
any_extension = False
seen: set[str] = set()
for item in values:
ext = str(item).strip().lower()
if not ext:
continue
if ext in {"*", ".*"}:
any_extension = True
continue
if not ext.startswith("."):
ext = "." + ext
if ext not in seen:
seen.add(ext)
normalized.append(ext)
if not normalized and not any_extension:
normalized = list(cls._DOWNLOAD_DEFAULT_EXTS)
return tuple(normalized), any_extension
def _track_current_task(self) -> asyncio.Task | None:
task = asyncio.current_task()
if task is not None:
self._client_tasks.add(task)
return task
def _untrack_task(self, task: asyncio.Task | None) -> None:
if task is not None:
self._client_tasks.discard(task)
@staticmethod @staticmethod
def _load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]: def _load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]:
"""Accept a list of host strings; return (exact_set, suffix_tuple). """Accept a list of host strings; return (exact_set, suffix_tuple).
@@ -352,15 +450,18 @@ class ProxyServer:
self.socks_host, self.socks_port, self.socks_host, self.socks_port,
) )
async with http_srv: try:
if socks_srv: async with http_srv:
async with socks_srv: if socks_srv:
await asyncio.gather( async with socks_srv:
http_srv.serve_forever(), await asyncio.gather(
socks_srv.serve_forever(), http_srv.serve_forever(),
) socks_srv.serve_forever(),
else: )
await http_srv.serve_forever() else:
await http_srv.serve_forever()
except asyncio.CancelledError:
raise
async def stop(self): async def stop(self):
"""Shut down all listeners and release relay resources.""" """Shut down all listeners and release relay resources."""
@@ -375,6 +476,15 @@ class ProxyServer:
except Exception: except Exception:
pass pass
self._servers = [] self._servers = []
current = asyncio.current_task()
client_tasks = [task for task in self._client_tasks if task is not current]
for task in client_tasks:
task.cancel()
if client_tasks:
await asyncio.gather(*client_tasks, return_exceptions=True)
self._client_tasks.clear()
try: try:
await self.fronter.close() await self.fronter.close()
except Exception as exc: except Exception as exc:
@@ -384,6 +494,7 @@ class ProxyServer:
async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
addr = writer.get_extra_info("peername") addr = writer.get_extra_info("peername")
task = self._track_current_task()
try: try:
first_line = await asyncio.wait_for(reader.readline(), timeout=30) first_line = await asyncio.wait_for(reader.readline(), timeout=30)
if not first_line: if not first_line:
@@ -400,6 +511,16 @@ class ProxyServer:
if line in (b"\r\n", b"\n", b""): if line in (b"\r\n", b"\n", b""):
break break
if _has_unsupported_transfer_encoding(header_block):
log.warning("Unsupported Transfer-Encoding on client request")
writer.write(
b"HTTP/1.1 501 Not Implemented\r\n"
b"Connection: close\r\n"
b"Content-Length: 0\r\n\r\n"
)
await writer.drain()
return
request_line = first_line.decode(errors="replace").strip() request_line = first_line.decode(errors="replace").strip()
parts = request_line.split(" ", 2) parts = request_line.split(" ", 2)
if len(parts) < 2: if len(parts) < 2:
@@ -412,11 +533,14 @@ class ProxyServer:
else: else:
await self._do_http(header_block, reader, writer) await self._do_http(header_block, reader, writer)
except asyncio.CancelledError:
pass
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.debug("Timeout: %s", addr) log.debug("Timeout: %s", addr)
except Exception as e: except Exception as e:
log.error("Error (%s): %s", addr, e) log.error("Error (%s): %s", addr, e)
finally: finally:
self._untrack_task(task)
try: try:
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
@@ -426,6 +550,7 @@ class ProxyServer:
async def _on_socks_client(self, reader: asyncio.StreamReader, async def _on_socks_client(self, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter): writer: asyncio.StreamWriter):
addr = writer.get_extra_info("peername") addr = writer.get_extra_info("peername")
task = self._track_current_task()
try: try:
header = await asyncio.wait_for(reader.readexactly(2), timeout=15) header = await asyncio.wait_for(reader.readexactly(2), timeout=15)
ver, nmethods = header[0], header[1] ver, nmethods = header[0], header[1]
@@ -475,11 +600,14 @@ class ProxyServer:
except asyncio.IncompleteReadError: except asyncio.IncompleteReadError:
pass pass
except asyncio.CancelledError:
pass
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.debug("SOCKS5 timeout: %s", addr) log.debug("SOCKS5 timeout: %s", addr)
except Exception as e: except Exception as e:
log.error("SOCKS5 error (%s): %s", addr, e) log.error("SOCKS5 error (%s): %s", addr, e)
finally: finally:
self._untrack_task(task)
try: try:
writer.close() writer.close()
await writer.wait_closed() await writer.wait_closed()
@@ -783,7 +911,9 @@ class ProxyServer:
         """
         target_ip = connect_ip or host
         try:
-            r_remote, w_remote = await self._open_tcp_connection(target_ip, port, timeout=10)
+            r_remote, w_remote = await self._open_tcp_connection(
+                target_ip, port, timeout=self._tcp_connect_timeout,
+            )
         except Exception as e:
             log.error("Direct tunnel connect failed (%s via %s): %s",
                       host, target_ip, e)
@@ -858,7 +988,7 @@ class ProxyServer:
                     ssl=ssl_ctx_client,
                     server_hostname=sni_out,
                 ),
-                timeout=10,
+                timeout=self._tcp_connect_timeout,
             )
         except Exception as e:
             log.error("SNI-rewrite outbound connect failed (%s via %s): %s",
@@ -969,6 +1099,15 @@ class ProxyServer:
                 # Read body
                 body = b""
+                if _has_unsupported_transfer_encoding(header_block):
+                    log.warning("Unsupported Transfer-Encoding → %s:%d", host, port)
+                    writer.write(
+                        b"HTTP/1.1 501 Not Implemented\r\n"
+                        b"Connection: close\r\n"
+                        b"Content-Length: 0\r\n\r\n"
+                    )
+                    await writer.drain()
+                    break
                 length = _parse_content_length(header_block)
                 if length > MAX_REQUEST_BODY_BYTES:
                     raise ValueError(f"Request body too large: {length} bytes")
@@ -1008,25 +1147,7 @@ class ProxyServer:
                 log.info("MITM → %s %s", method, url)
-                # ── CORS: extract relevant request headers ────────────────────
-                origin = next(
-                    (v for k, v in headers.items() if k.lower() == "origin"), ""
-                )
-                acr_method = next(
-                    (v for k, v in headers.items()
-                     if k.lower() == "access-control-request-method"), ""
-                )
-                acr_headers = next(
-                    (v for k, v in headers.items()
-                     if k.lower() == "access-control-request-headers"), ""
-                )
-                # CORS preflight — respond directly; UrlFetchApp doesn't
-                # support OPTIONS so forwarding it would always fail.
-                if method.upper() == "OPTIONS" and acr_method:
-                    log.debug("CORS preflight → %s (responding locally)", url[:60])
-                    writer.write(self._cors_preflight_response(origin, acr_method, acr_headers))
-                    await writer.drain()
+                if await self._maybe_stream_download(method, url, headers, body, writer):
                     continue
                 # Check local cache first (GET only)
@@ -1057,11 +1178,6 @@ class ProxyServer:
                             self._cache.put(url, response, ttl)
                             log.debug("Cached (%ds): %s", ttl, url[:60])
-                # Inject permissive CORS headers whenever the browser
-                # sent an Origin (cross-origin XHR / fetch).
-                if origin and response:
-                    response = self._inject_cors_headers(response, origin)
                 self._log_response_summary(url, response)
                 writer.write(response)
@@ -1077,64 +1193,6 @@ class ProxyServer:
                 log.error("MITM handler error (%s): %s", host, e)
                 break
-    # ── CORS helpers ──────────────────────────────────────────────────────────
-    @staticmethod
-    def _cors_preflight_response(origin: str, acr_method: str, acr_headers: str) -> bytes:
-        """Return a 204 No Content response that satisfies a CORS preflight."""
-        allow_origin = origin or "*"
-        allow_methods = (
-            f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS"
-            if acr_method else
-            "GET, POST, PUT, DELETE, PATCH, OPTIONS"
-        )
-        allow_headers = acr_headers or "*"
-        return (
-            "HTTP/1.1 204 No Content\r\n"
-            f"Access-Control-Allow-Origin: {allow_origin}\r\n"
-            f"Access-Control-Allow-Methods: {allow_methods}\r\n"
-            f"Access-Control-Allow-Headers: {allow_headers}\r\n"
-            "Access-Control-Allow-Credentials: true\r\n"
-            "Access-Control-Max-Age: 86400\r\n"
-            "Vary: Origin\r\n"
-            "Content-Length: 0\r\n"
-            "\r\n"
-        ).encode()
-    @staticmethod
-    def _inject_cors_headers(response: bytes, origin: str) -> bytes:
-        """Inject CORS headers only if the upstream response lacks them.
-        We must NOT overwrite the origin server's CORS headers: sites like
-        x.com return carefully-scoped Access-Control-Allow-Headers that list
-        specific custom headers (e.g. x-csrf-token). Replacing them with
-        wildcards together with Allow-Credentials: true makes browsers
-        reject the response (per the Fetch spec, "*" is literal when
-        credentials are included), which the site then blames on privacy
-        extensions. So we only fill in what the server omitted.
-        """
-        sep = b"\r\n\r\n"
-        if sep not in response:
-            return response
-        header_section, body = response.split(sep, 1)
-        lines = header_section.decode(errors="replace").split("\r\n")
-        existing = {ln.split(":", 1)[0].strip().lower()
-                    for ln in lines if ":" in ln}
-        # If the upstream already handled CORS, leave it completely alone.
-        if "access-control-allow-origin" in existing:
-            return response
-        # Otherwise inject a minimal, credential-safe set (no wildcards,
-        # since wildcards combined with credentials are invalid).
-        allow_origin = origin or "*"
-        additions = [f"Access-Control-Allow-Origin: {allow_origin}"]
-        if allow_origin != "*":
-            additions.append("Access-Control-Allow-Credentials: true")
-            additions.append("Vary: Origin")
-        return ("\r\n".join(lines + additions) + "\r\n\r\n").encode() + body
     async def _relay_smart(self, method, url, headers, body):
         """Choose optimal relay strategy based on request type.
@@ -1156,22 +1214,67 @@ class ProxyServer:
             # Only probe with Range when the URL looks like a big file.
             if self._is_likely_download(url, headers):
                 return await self.fronter.relay_parallel(
-                    method, url, headers, body
+                    method,
+                    url,
+                    headers,
+                    body,
+                    chunk_size=self._download_chunk_size,
+                    max_parallel=self._download_max_parallel,
+                    max_chunks=self._download_max_chunks,
+                    min_size=self._download_min_size,
                 )
         return await self.fronter.relay(method, url, headers, body)
     def _is_likely_download(self, url: str, headers: dict) -> bool:
         """Heuristic: is this URL likely a large file download?"""
         path = url.split("?")[0].lower()
-        for ext in LARGE_FILE_EXTS:
+        if self._download_any_extension:
+            return True
+        for ext in self._download_extensions:
             if path.endswith(ext):
                 return True
+        accept = self._header_value(headers, "accept").lower()
+        if any(marker in accept for marker in self._DOWNLOAD_ACCEPT_MARKERS):
+            return True
         return False
+    async def _maybe_stream_download(self, method: str, url: str,
+                                     headers: dict | None, body: bytes,
+                                     writer) -> bool:
+        if method.upper() != "GET" or body:
+            return False
+        if headers:
+            for key in headers:
+                if key.lower() == "range":
+                    return False
+        effective_headers = headers or {}
+        if not self._is_likely_download(url, effective_headers):
+            return False
+        if not self.fronter.stream_download_allowed(url):
+            return False
+        return await self.fronter.stream_parallel_download(
+            url,
+            effective_headers,
+            writer,
+            chunk_size=self._download_chunk_size,
+            max_parallel=self._download_max_parallel,
+            max_chunks=self._download_max_chunks,
+            min_size=self._download_min_size,
+        )
     # ── Plain HTTP forwarding ─────────────────────────────────────
     async def _do_http(self, header_block: bytes, reader, writer):
         body = b""
+        if _has_unsupported_transfer_encoding(header_block):
+            log.warning("Unsupported Transfer-Encoding on plain HTTP request")
+            writer.write(
+                b"HTTP/1.1 501 Not Implemented\r\n"
+                b"Connection: close\r\n"
+                b"Content-Length: 0\r\n\r\n"
+            )
+            await writer.drain()
+            return
         length = _parse_content_length(header_block)
         if length > MAX_REQUEST_BODY_BYTES:
             writer.write(b"HTTP/1.1 413 Content Too Large\r\n\r\n")
@@ -1194,22 +1297,7 @@ class ProxyServer:
                 k, v = raw_line.decode(errors="replace").split(":", 1)
                 headers[k.strip()] = v.strip()
-        # ── CORS preflight over plain HTTP ────────────────────────────
-        origin = next(
-            (v for k, v in headers.items() if k.lower() == "origin"), ""
-        )
-        acr_method = next(
-            (v for k, v in headers.items()
-             if k.lower() == "access-control-request-method"), ""
-        )
-        acr_headers_val = next(
-            (v for k, v in headers.items()
-             if k.lower() == "access-control-request-headers"), ""
-        )
-        if method.upper() == "OPTIONS" and acr_method:
-            log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60])
-            writer.write(self._cors_preflight_response(origin, acr_method, acr_headers_val))
-            await writer.drain()
+        if await self._maybe_stream_download(method, url, headers, body, writer):
             return
         # Cache check for GET
@@ -1227,9 +1315,6 @@ class ProxyServer:
                 if ttl > 0:
                     self._cache.put(url, response, ttl)
-        # Inject CORS headers for cross-origin requests
-        if origin and response:
-            response = self._inject_cors_headers(response, origin)
         self._log_response_summary(url, response)
         writer.write(response)