mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-17 21:24:37 +03:00
Improve relay stability and add streamed parallel downloads
This commit is contained in:
@@ -0,0 +1,140 @@
|
|||||||
|
# Suggested PR Title
|
||||||
|
|
||||||
|
Improve relay stability and add streamed parallel downloads for large files
|
||||||
|
|
||||||
|
## Summary
|
||||||
|
|
||||||
|
This PR focuses on making the Apps Script relay safer, more stable, and significantly more usable for large downloads.
|
||||||
|
|
||||||
|
The biggest change is a new streamed parallel download path for likely large files. Instead of buffering the entire file and only handing it to the browser at the end, the proxy can now start sending data to the client immediately while fetching the remaining ranges in parallel. This gives users real browser download progress and fixes the old behavior where the browser looked stuck on loading until the full file had already finished downloading in the background.
|
||||||
|
|
||||||
|
Alongside that, this PR hardens request handling, improves shutdown behavior, adds safer defaults, reduces quota pressure for large downloads, and makes HTTP/2 fallback behavior much more defensive when the transport becomes unstable.
|
||||||
|
|
||||||
|
## What Changed
|
||||||
|
|
||||||
|
### Large-file downloads
|
||||||
|
|
||||||
|
- Added a streamed parallel download path for likely large downloads
|
||||||
|
- Started sending response headers/body to the client immediately instead of waiting for full buffering
|
||||||
|
- Added disk-backed spooling for ordered chunk delivery during streaming downloads
|
||||||
|
- Added per-range validation using `Content-Range` and expected body length checks
|
||||||
|
- Added initial range-probe retries before falling back
|
||||||
|
- Added host cooldown for flaky streaming download targets so repeatedly failing hosts do not keep producing broken partial downloads
|
||||||
|
- Moved range-download traffic off the shared H2 relay path and onto the H1 pool to avoid destabilizing normal relay traffic
|
||||||
|
- Added progress logging in a more readable terminal format:
|
||||||
|
- elapsed time
|
||||||
|
- progress bar
|
||||||
|
- bytes downloaded / total
|
||||||
|
- current download speed
|
||||||
|
- Added `.bin` to default large-download extension heuristics
|
||||||
|
|
||||||
|
### Stability and relay behavior
|
||||||
|
|
||||||
|
- Added configurable relay and connect timeouts
|
||||||
|
- Added configurable response body cap enforcement across buffered relay paths
|
||||||
|
- Made retry behavior safer by limiting retries for non-idempotent requests
|
||||||
|
- Improved GET request coalescing by including representation-relevant headers in the coalescing key
|
||||||
|
- Added bounded chunk/request tuning for large downloads to reduce quota pressure
|
||||||
|
- Added safer defaults for:
|
||||||
|
- `chunked_download_chunk_size`
|
||||||
|
- `chunked_download_max_parallel`
|
||||||
|
- `chunked_download_max_chunks`
|
||||||
|
- Added explicit handling for unsupported request `Transfer-Encoding`
|
||||||
|
|
||||||
|
### HTTP/2 hardening
|
||||||
|
|
||||||
|
- Fixed H2 reconnect/reader lifecycle issues that could cause reconnect storms or stale-reader interference
|
||||||
|
- Added temporary H2 cooldown/circuit-breaker behavior after repeated consecutive H2 failures
|
||||||
|
- Reduced noisy close-notify behavior and improved connection lifecycle handling
|
||||||
|
- Prevented large parallel downloads from destabilizing the shared H2 relay path used by normal traffic
|
||||||
|
|
||||||
|
### Security / defaults / config hygiene
|
||||||
|
|
||||||
|
- Removed permissive CORS behavior that effectively bypassed browser CORS protections
|
||||||
|
- Changed `lan_sharing` to be opt-in by default
|
||||||
|
- Updated setup flow so LAN sharing is explicitly prompted instead of silently inheriting insecure defaults
|
||||||
|
- Synced docs and config examples with actual runtime behavior
|
||||||
|
|
||||||
|
### Shutdown / cleanup
|
||||||
|
|
||||||
|
- Improved shutdown cleanup so active client tasks are tracked, cancelled, and awaited during stop
|
||||||
|
- Reduced `Task was destroyed but it is pending` noise on normal shutdown
|
||||||
|
|
||||||
|
## Config Additions / Changes
|
||||||
|
|
||||||
|
Added or documented the following config options:
|
||||||
|
|
||||||
|
- `relay_timeout`
|
||||||
|
- `tls_connect_timeout`
|
||||||
|
- `tcp_connect_timeout`
|
||||||
|
- `max_response_body_bytes`
|
||||||
|
- `chunked_download_extensions`
|
||||||
|
- `chunked_download_min_size`
|
||||||
|
- `chunked_download_chunk_size`
|
||||||
|
- `chunked_download_max_parallel`
|
||||||
|
- `chunked_download_max_chunks`
|
||||||
|
|
||||||
|
Default download tuning was also made more conservative to improve stability.
|
||||||
|
|
||||||
|
## Bugs Fixed
|
||||||
|
|
||||||
|
- Large downloads appearing only after full buffering instead of progressing in-browser
|
||||||
|
- Browser download UX looking stalled while the proxy was still downloading in the background
|
||||||
|
- Frequent range chunk corruption/acceptance without validating `Content-Range`
|
||||||
|
- H2 reconnect/reader race behavior causing repeated `H2 reader loop ended` / `Connection lost` cascades
|
||||||
|
- Insecure default LAN exposure
|
||||||
|
- Dangerous CORS response injection behavior
|
||||||
|
- Inconsistent response size-cap enforcement
|
||||||
|
- Pending asyncio task noise during shutdown
|
||||||
|
- Silent mishandling of unsupported `Transfer-Encoding` requests
|
||||||
|
|
||||||
|
## Manual Testing
|
||||||
|
|
||||||
|
### Download tests
|
||||||
|
|
||||||
|
- `https://ash-speed.hetzner.com/100MB.bin`
|
||||||
|
- Download completed successfully
|
||||||
|
- Browser showed progressive download behavior
|
||||||
|
- Terminal progress output updated correctly
|
||||||
|
- `https://fsn1-speed.hetzner.com/100MB.bin`
|
||||||
|
- Exposed host-specific flakiness during range streaming
|
||||||
|
- Added host cooldown / fallback protection after repeated streaming failures
|
||||||
|
- `https://fsn1-speed.hetzner.com/1GB.bin`
|
||||||
|
- Download completed successfully
|
||||||
|
- Completed in under 5 minutes in real-world testing
|
||||||
|
- Browser showed real incremental progress instead of waiting for full buffering
|
||||||
|
|
||||||
|
### Runtime / startup / shutdown checks
|
||||||
|
|
||||||
|
- Verified startup with the current config
|
||||||
|
- Verified graceful `Ctrl+C` shutdown behavior after cleanup changes
|
||||||
|
- Ran compile validation:
|
||||||
|
- `python3 -m compileall main.py src setup.py`
|
||||||
|
|
||||||
|
## Result / Impact
|
||||||
|
|
||||||
|
This PR substantially improves the user experience and resilience of the proxy:
|
||||||
|
|
||||||
|
- Large downloads now behave like real downloads in the browser
|
||||||
|
- 100MB and 1GB files can complete successfully through the relay
|
||||||
|
- The system no longer feels stuck during large transfers
|
||||||
|
- Large-download traffic is isolated from the shared H2 relay path, reducing collateral instability
|
||||||
|
- Repeated H2 failures now degrade more safely instead of spamming reconnect errors
|
||||||
|
- Defaults are safer, especially for LAN exposure and browser security behavior
|
||||||
|
- Shutdown is cleaner and less noisy
|
||||||
|
|
||||||
|
## Notes
|
||||||
|
|
||||||
|
- Some hosts are still inherently flakier than others for parallel range downloading through Apps Script, so the host cooldown/fallback behavior is intentional and defensive
|
||||||
|
- This PR prioritizes correctness, safer degradation, and usable large-file behavior over forcing aggressive parallelism on every host
|
||||||
|
|
||||||
|
## Screenshot
|
||||||
|
|
||||||
|
Use the attached screenshot in the PR to show both browser-side progress and terminal-side progress:
|
||||||
|
|
||||||
|
- Local file: `/Users/pouriarc/Downloads/photo_2026-04-23 14.55.20.jpeg`
|
||||||
|
|
||||||
|
Recommended caption:
|
||||||
|
|
||||||
|
> 1GB download progressing normally in the browser while the relay reports live chunk progress and throughput in the terminal.
|
||||||
|
|
||||||
@@ -261,8 +261,17 @@ This project focuses entirely on the **Apps Script** relay — a free Google acc
|
|||||||
|---------|---------|-------------|
|
|---------|---------|-------------|
|
||||||
| `google_ip` | `216.239.38.120` | Google IP address to connect through |
|
| `google_ip` | `216.239.38.120` | Google IP address to connect through |
|
||||||
| `front_domain` | `www.google.com` | Domain shown to the firewall/filter |
|
| `front_domain` | `www.google.com` | Domain shown to the firewall/filter |
|
||||||
| `verify_ssl` | `true` | Verify TLS certificates |
|
| `verify_ssl` | `true` | Verify the TLS certificate on the local fronted connection to Google/CDN |
|
||||||
|
| `relay_timeout` | `25` | Total timeout for one relayed request before it fails |
|
||||||
|
| `tls_connect_timeout` | `15` | Timeout for the proxy's TLS connection to the fronted Google/CDN endpoint |
|
||||||
|
| `tcp_connect_timeout` | `10` | Timeout for direct TCP tunnels and outbound SNI-rewrite connects |
|
||||||
|
| `max_response_body_bytes` | `209715200` | Hard cap for a single relay response body after buffering/decoding |
|
||||||
| `script_ids` | — | Multiple Script IDs for load balancing (array) |
|
| `script_ids` | — | Multiple Script IDs for load balancing (array) |
|
||||||
|
| `chunked_download_extensions` | see [config.example.json](config.example.json) | File extensions that should use parallel range downloading. Supports `".*"` to probe all GET downloads. |
|
||||||
|
| `chunked_download_min_size` | `5242880` | Minimum total file size (5 MB) before range-parallel download stays enabled |
|
||||||
|
| `chunked_download_chunk_size` | `524288` | Per-range chunk size used by parallel downloads |
|
||||||
|
| `chunked_download_max_parallel` | `8` | Maximum simultaneous range requests for one download |
|
||||||
|
| `chunked_download_max_chunks` | `256` | Soft upper bound for total chunk requests; chunk size is raised automatically for very large files |
|
||||||
| `block_hosts` | `[]` | Hosts that must never be tunneled (return HTTP 403). Supports exact names (`ads.example.com`) or leading-dot suffixes (`.doubleclick.net`). |
|
| `block_hosts` | `[]` | Hosts that must never be tunneled (return HTTP 403). Supports exact names (`ads.example.com`) or leading-dot suffixes (`.doubleclick.net`). |
|
||||||
| `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM, no relay). Useful for LAN resources or sites that break under MITM. |
|
| `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | Hosts that go direct (no MITM, no relay). Useful for LAN resources or sites that break under MITM. |
|
||||||
| `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the fast direct tunnel. |
|
| `direct_google_exclude` | see [config.example.json](config.example.json) | Google apps that must use the MITM relay path instead of the fast direct tunnel. |
|
||||||
|
|||||||
+10
-1
@@ -210,8 +210,17 @@ json
|
|||||||
|------|---------------|-------|
|
|------|---------------|-------|
|
||||||
| `google_ip` | `216.239.38.120` | IP مورد استفاده برای مسیر Google |
|
| `google_ip` | `216.239.38.120` | IP مورد استفاده برای مسیر Google |
|
||||||
| `front_domain` | `www.google.com` | دامنهای که فیلتر میبیند |
|
| `front_domain` | `www.google.com` | دامنهای که فیلتر میبیند |
|
||||||
| `verify_ssl` | `true` | بررسی اعتبار TLS |
|
| `verify_ssl` | `true` | بررسی اعتبار TLS فقط برای اتصال fronted محلی به Google/CDN |
|
||||||
|
| `relay_timeout` | `25` | مهلت کل برای هر درخواست relay قبل از fail شدن |
|
||||||
|
| `tls_connect_timeout` | `15` | مهلت اتصال TLS پروکسی به endpoint fronted روی Google/CDN |
|
||||||
|
| `tcp_connect_timeout` | `10` | مهلت اتصال برای tunnel مستقیم و SNI-rewrite |
|
||||||
|
| `max_response_body_bytes` | `209715200` | سقف نهایی برای اندازه body هر پاسخ relay بعد از buffer/decode |
|
||||||
| `script_ids` | - | چند Deployment ID برای load balancing |
|
| `script_ids` | - | چند Deployment ID برای load balancing |
|
||||||
|
| `chunked_download_extensions` | مطابق [config.example.json](config.example.json) | پسوند فایلهایی که باید از دانلود range-parallel استفاده کنند. از `".*"` هم برای probe همه دانلودهای GET پشتیبانی میشود. |
|
||||||
|
| `chunked_download_min_size` | `5242880` | حداقل اندازه کل فایل (۵ مگابایت) برای فعال ماندن دانلود موازی |
|
||||||
|
| `chunked_download_chunk_size` | `524288` | اندازه هر chunk در دانلود موازی |
|
||||||
|
| `chunked_download_max_parallel` | `8` | حداکثر تعداد range request همزمان برای یک دانلود |
|
||||||
|
| `chunked_download_max_chunks` | `256` | سقف نرم برای تعداد کل chunk request ها؛ برای فایلهای خیلی بزرگ اندازه chunk بهصورت خودکار بیشتر میشود |
|
||||||
| `block_hosts` | `[]` | هاستهایی که هرگز نباید tunnel شوند (پاسخ 403). نام دقیق (`ads.example.com`) یا پسوند با نقطهی ابتدایی (`.doubleclick.net`). |
|
| `block_hosts` | `[]` | هاستهایی که هرگز نباید tunnel شوند (پاسخ 403). نام دقیق (`ads.example.com`) یا پسوند با نقطهی ابتدایی (`.doubleclick.net`). |
|
||||||
| `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | هاستهایی که مستقیم میروند (بدون MITM و بدون رله). برای منابع داخلی شبکه یا سایتهایی که با MITM مشکل دارند. |
|
| `bypass_hosts` | `["localhost", ".local", ".lan", ".home.arpa"]` | هاستهایی که مستقیم میروند (بدون MITM و بدون رله). برای منابع داخلی شبکه یا سایتهایی که با MITM مشکل دارند. |
|
||||||
| `direct_google_exclude` | مراجعه به [config.example.json](config.example.json) | اپهای Google که باید از مسیر MITM برای رله استفاده کنند بهجای tunnel مستقیم. |
|
| `direct_google_exclude` | مراجعه به [config.example.json](config.example.json) | اپهای Google که باید از مسیر MITM برای رله استفاده کنند بهجای tunnel مستقیم. |
|
||||||
|
|||||||
+42
-1
@@ -10,8 +10,49 @@
|
|||||||
"socks5_port": 1080,
|
"socks5_port": 1080,
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
"verify_ssl": true,
|
"verify_ssl": true,
|
||||||
"lan_sharing": true,
|
"lan_sharing": false,
|
||||||
|
"relay_timeout": 25,
|
||||||
|
"tls_connect_timeout": 15,
|
||||||
|
"tcp_connect_timeout": 10,
|
||||||
|
"max_response_body_bytes": 209715200,
|
||||||
"parallel_relay": 1,
|
"parallel_relay": 1,
|
||||||
|
"chunked_download_extensions": [
|
||||||
|
".bin",
|
||||||
|
".zip",
|
||||||
|
".tar",
|
||||||
|
".gz",
|
||||||
|
".bz2",
|
||||||
|
".xz",
|
||||||
|
".7z",
|
||||||
|
".rar",
|
||||||
|
".exe",
|
||||||
|
".msi",
|
||||||
|
".dmg",
|
||||||
|
".deb",
|
||||||
|
".rpm",
|
||||||
|
".apk",
|
||||||
|
".iso",
|
||||||
|
".img",
|
||||||
|
".mp4",
|
||||||
|
".mkv",
|
||||||
|
".avi",
|
||||||
|
".mov",
|
||||||
|
".webm",
|
||||||
|
".mp3",
|
||||||
|
".flac",
|
||||||
|
".wav",
|
||||||
|
".aac",
|
||||||
|
".pdf",
|
||||||
|
".doc",
|
||||||
|
".docx",
|
||||||
|
".ppt",
|
||||||
|
".pptx",
|
||||||
|
".wasm"
|
||||||
|
],
|
||||||
|
"chunked_download_min_size": 5242880,
|
||||||
|
"chunked_download_chunk_size": 524288,
|
||||||
|
"chunked_download_max_parallel": 8,
|
||||||
|
"chunked_download_max_chunks": 256,
|
||||||
"block_hosts": [],
|
"block_hosts": [],
|
||||||
"bypass_hosts": [
|
"bypass_hosts": [
|
||||||
"localhost",
|
"localhost",
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ def main():
|
|||||||
log.info("MITM CA is already trusted.")
|
log.info("MITM CA is already trusted.")
|
||||||
|
|
||||||
# ── LAN sharing configuration ────────────────────────────────────────
|
# ── LAN sharing configuration ────────────────────────────────────────
|
||||||
lan_sharing = config.get("lan_sharing", True)
|
lan_sharing = config.get("lan_sharing", False)
|
||||||
if lan_sharing:
|
if lan_sharing:
|
||||||
# If LAN sharing is enabled and host is still localhost, change to all interfaces
|
# If LAN sharing is enabled and host is still localhost, change to all interfaces
|
||||||
if config.get("listen_host", "127.0.0.1") == "127.0.0.1":
|
if config.get("listen_host", "127.0.0.1") == "127.0.0.1":
|
||||||
|
|||||||
@@ -86,6 +86,15 @@ def load_base_config() -> dict:
|
|||||||
"socks5_port": 1080,
|
"socks5_port": 1080,
|
||||||
"log_level": "INFO",
|
"log_level": "INFO",
|
||||||
"verify_ssl": True,
|
"verify_ssl": True,
|
||||||
|
"lan_sharing": False,
|
||||||
|
"relay_timeout": 25,
|
||||||
|
"tls_connect_timeout": 15,
|
||||||
|
"tcp_connect_timeout": 10,
|
||||||
|
"max_response_body_bytes": 200 * 1024 * 1024,
|
||||||
|
"chunked_download_min_size": 5 * 1024 * 1024,
|
||||||
|
"chunked_download_chunk_size": 512 * 1024,
|
||||||
|
"chunked_download_max_parallel": 8,
|
||||||
|
"chunked_download_max_chunks": 256,
|
||||||
"hosts": {},
|
"hosts": {},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +127,15 @@ def configure_apps_script(cfg: dict) -> dict:
|
|||||||
def configure_network(cfg: dict) -> dict:
|
def configure_network(cfg: dict) -> dict:
|
||||||
print()
|
print()
|
||||||
print(bold("Network settings") + dim(" (press enter to accept defaults)"))
|
print(bold("Network settings") + dim(" (press enter to accept defaults)"))
|
||||||
cfg["listen_host"] = prompt("Listen host", default=str(cfg.get("listen_host", "127.0.0.1")))
|
cfg["lan_sharing"] = prompt_yes_no(
|
||||||
|
"Enable LAN sharing?",
|
||||||
|
default=bool(cfg.get("lan_sharing", False)),
|
||||||
|
)
|
||||||
|
|
||||||
|
default_host = str(cfg.get("listen_host", "127.0.0.1"))
|
||||||
|
if cfg["lan_sharing"] and default_host == "127.0.0.1":
|
||||||
|
default_host = "0.0.0.0"
|
||||||
|
cfg["listen_host"] = prompt("Listen host", default=default_host)
|
||||||
|
|
||||||
port = prompt("HTTP proxy port", default=str(cfg.get("listen_port", 8085)))
|
port = prompt("HTTP proxy port", default=str(cfg.get("listen_port", 8085)))
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -165,6 +165,7 @@ STATIC_EXTS: tuple[str, ...] = (
|
|||||||
".mp3", ".mp4", ".webm", ".wasm", ".avif",
|
".mp3", ".mp4", ".webm", ".wasm", ".avif",
|
||||||
)
|
)
|
||||||
LARGE_FILE_EXTS = frozenset({
|
LARGE_FILE_EXTS = frozenset({
|
||||||
|
".bin",
|
||||||
".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar",
|
".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar",
|
||||||
".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk",
|
".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk",
|
||||||
".iso", ".img",
|
".iso", ".img",
|
||||||
|
|||||||
+649
-57
@@ -16,6 +16,7 @@ import logging
|
|||||||
import re
|
import re
|
||||||
import socket
|
import socket
|
||||||
import ssl
|
import ssl
|
||||||
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
@@ -27,6 +28,7 @@ from constants import (
|
|||||||
BATCH_WINDOW_MICRO,
|
BATCH_WINDOW_MICRO,
|
||||||
CONN_TTL,
|
CONN_TTL,
|
||||||
FRONT_SNI_POOL_GOOGLE,
|
FRONT_SNI_POOL_GOOGLE,
|
||||||
|
MAX_RESPONSE_BODY_BYTES,
|
||||||
POOL_MAX,
|
POOL_MAX,
|
||||||
POOL_MIN_IDLE,
|
POOL_MIN_IDLE,
|
||||||
RELAY_TIMEOUT,
|
RELAY_TIMEOUT,
|
||||||
@@ -83,6 +85,18 @@ def _build_sni_pool(front_domain: str, overrides: list | None) -> list[str]:
|
|||||||
|
|
||||||
class DomainFronter:
|
class DomainFronter:
|
||||||
_STATIC_EXTS = STATIC_EXTS
|
_STATIC_EXTS = STATIC_EXTS
|
||||||
|
_H2_FAILURE_COOLDOWN = 60.0
|
||||||
|
_H2_FAILURE_THRESHOLD = 3
|
||||||
|
_DOWNLOAD_STREAM_COOLDOWN = 300.0
|
||||||
|
_COALESCE_VARY_HEADERS = (
|
||||||
|
"accept",
|
||||||
|
"accept-language",
|
||||||
|
"user-agent",
|
||||||
|
"sec-fetch-dest",
|
||||||
|
"sec-fetch-mode",
|
||||||
|
"sec-fetch-site",
|
||||||
|
)
|
||||||
|
_SAFE_RETRY_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||||||
|
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
self.connect_host = config.get("google_ip", "216.239.38.120")
|
self.connect_host = config.get("google_ip", "216.239.38.120")
|
||||||
@@ -120,6 +134,16 @@ class DomainFronter:
|
|||||||
|
|
||||||
self.auth_key = config.get("auth_key", "")
|
self.auth_key = config.get("auth_key", "")
|
||||||
self.verify_ssl = config.get("verify_ssl", True)
|
self.verify_ssl = config.get("verify_ssl", True)
|
||||||
|
self._relay_timeout = self._cfg_float(
|
||||||
|
config, "relay_timeout", RELAY_TIMEOUT, minimum=1.0,
|
||||||
|
)
|
||||||
|
self._tls_connect_timeout = self._cfg_float(
|
||||||
|
config, "tls_connect_timeout", TLS_CONNECT_TIMEOUT, minimum=1.0,
|
||||||
|
)
|
||||||
|
self._max_response_body_bytes = self._cfg_int(
|
||||||
|
config, "max_response_body_bytes", MAX_RESPONSE_BODY_BYTES,
|
||||||
|
minimum=1024,
|
||||||
|
)
|
||||||
|
|
||||||
# Connection pool — TTL-based, pre-warmed, with concurrency control
|
# Connection pool — TTL-based, pre-warmed, with concurrency control
|
||||||
self._pool: list[tuple[asyncio.StreamReader, asyncio.StreamWriter, float]] = []
|
self._pool: list[tuple[asyncio.StreamReader, asyncio.StreamWriter, float]] = []
|
||||||
@@ -146,6 +170,9 @@ class DomainFronter:
|
|||||||
|
|
||||||
# Request coalescing — dedup concurrent identical GETs
|
# Request coalescing — dedup concurrent identical GETs
|
||||||
self._coalesce: dict[str, list[asyncio.Future]] = {}
|
self._coalesce: dict[str, list[asyncio.Future]] = {}
|
||||||
|
self._h2_failure_streak = 0
|
||||||
|
self._h2_disabled_until = 0.0
|
||||||
|
self._stream_download_disabled_until: dict[str, float] = {}
|
||||||
|
|
||||||
# HTTP/2 multiplexing — one connection handles all requests
|
# HTTP/2 multiplexing — one connection handles all requests
|
||||||
self._h2 = None
|
self._h2 = None
|
||||||
@@ -173,6 +200,23 @@ class DomainFronter:
|
|||||||
|
|
||||||
# ── helpers ───────────────────────────────────────────────────
|
# ── helpers ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int:
|
||||||
|
try:
|
||||||
|
value = int(config.get(key, default))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = default
|
||||||
|
return max(minimum, value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cfg_float(config: dict, key: str, default: float,
|
||||||
|
*, minimum: float = 0.1) -> float:
|
||||||
|
try:
|
||||||
|
value = float(config.get(key, default))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = default
|
||||||
|
return max(minimum, value)
|
||||||
|
|
||||||
def _ssl_ctx(self) -> ssl.SSLContext:
|
def _ssl_ctx(self) -> ssl.SSLContext:
|
||||||
ctx = ssl.create_default_context()
|
ctx = ssl.create_default_context()
|
||||||
if not self.verify_ssl:
|
if not self.verify_ssl:
|
||||||
@@ -180,6 +224,54 @@ class DomainFronter:
|
|||||||
ctx.verify_mode = ssl.CERT_NONE
|
ctx.verify_mode = ssl.CERT_NONE
|
||||||
return ctx
|
return ctx
|
||||||
|
|
||||||
|
def _h2_available(self) -> bool:
|
||||||
|
return (
|
||||||
|
self._h2 is not None
|
||||||
|
and self._h2.is_connected
|
||||||
|
and time.time() >= self._h2_disabled_until
|
||||||
|
)
|
||||||
|
|
||||||
|
def _record_h2_success(self) -> None:
|
||||||
|
self._h2_failure_streak = 0
|
||||||
|
|
||||||
|
def _record_h2_failure(self, exc: Exception) -> None:
|
||||||
|
self._h2_failure_streak += 1
|
||||||
|
if self._h2_failure_streak >= self._H2_FAILURE_THRESHOLD:
|
||||||
|
self._h2_disabled_until = time.time() + self._H2_FAILURE_COOLDOWN
|
||||||
|
log.warning(
|
||||||
|
"H2 temporarily disabled for %.0fs after %d consecutive failures (%s)",
|
||||||
|
self._H2_FAILURE_COOLDOWN,
|
||||||
|
self._h2_failure_streak,
|
||||||
|
type(exc).__name__,
|
||||||
|
)
|
||||||
|
self._h2_failure_streak = 0
|
||||||
|
|
||||||
|
def _stream_download_allowed(self, url: str) -> bool:
|
||||||
|
host = self._host_key(url)
|
||||||
|
if not host:
|
||||||
|
return True
|
||||||
|
until = self._stream_download_disabled_until.get(host, 0.0)
|
||||||
|
if until > time.time():
|
||||||
|
return False
|
||||||
|
if until:
|
||||||
|
self._stream_download_disabled_until.pop(host, None)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _mark_stream_download_failure(self, url: str, reason: str) -> None:
|
||||||
|
host = self._host_key(url)
|
||||||
|
if not host:
|
||||||
|
return
|
||||||
|
self._stream_download_disabled_until[host] = (
|
||||||
|
time.time() + self._DOWNLOAD_STREAM_COOLDOWN
|
||||||
|
)
|
||||||
|
log.warning(
|
||||||
|
"Parallel streaming disabled for host %s for %.0fs after failure (%s)",
|
||||||
|
host, self._DOWNLOAD_STREAM_COOLDOWN, reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
def stream_download_allowed(self, url: str) -> bool:
|
||||||
|
return self._stream_download_allowed(url)
|
||||||
|
|
||||||
async def _open(self):
|
async def _open(self):
|
||||||
"""Open a TLS connection to the CDN.
|
"""Open a TLS connection to the CDN.
|
||||||
|
|
||||||
@@ -228,7 +320,7 @@ class DomainFronter:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
reader, writer = await asyncio.wait_for(
|
reader, writer = await asyncio.wait_for(
|
||||||
self._open(), timeout=TLS_CONNECT_TIMEOUT
|
self._open(), timeout=self._tls_connect_timeout
|
||||||
)
|
)
|
||||||
# Pool was empty — trigger aggressive background refill
|
# Pool was empty — trigger aggressive background refill
|
||||||
if not self._refilling:
|
if not self._refilling:
|
||||||
@@ -327,6 +419,171 @@ class DomainFronter:
|
|||||||
host = parsed.hostname or url_or_host
|
host = parsed.hostname or url_or_host
|
||||||
return host.lower().rstrip(".")
|
return host.lower().rstrip(".")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _coalesce_key(cls, url: str, headers: dict | None) -> str:
|
||||||
|
key = [url]
|
||||||
|
if headers:
|
||||||
|
lowered = {str(k).lower(): str(v) for k, v in headers.items()}
|
||||||
|
for name in cls._COALESCE_VARY_HEADERS:
|
||||||
|
value = lowered.get(name)
|
||||||
|
if value:
|
||||||
|
key.append(f"{name}={value}")
|
||||||
|
return "\n".join(key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _retry_attempts_for_payload(cls, payload: dict) -> int:
|
||||||
|
method = str(payload.get("m", "GET")).upper()
|
||||||
|
return 2 if method in cls._SAFE_RETRY_METHODS else 1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _render_streaming_headers(resp_headers: dict, total_size: int) -> bytes:
|
||||||
|
lines = ["HTTP/1.1 200 OK"]
|
||||||
|
skip = {
|
||||||
|
"transfer-encoding",
|
||||||
|
"connection",
|
||||||
|
"keep-alive",
|
||||||
|
"content-length",
|
||||||
|
"content-range",
|
||||||
|
}
|
||||||
|
for key, value in resp_headers.items():
|
||||||
|
if key.lower() in skip:
|
||||||
|
continue
|
||||||
|
lines.append(f"{key}: {value}")
|
||||||
|
lines.append(f"Content-Length: {total_size}")
|
||||||
|
lines.append("")
|
||||||
|
lines.append("")
|
||||||
|
return "\r\n".join(lines).encode()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_content_range(value: str) -> tuple[int, int, int] | None:
|
||||||
|
match = re.match(r"^\s*bytes\s+(\d+)-(\d+)/(\d+)\s*$", value or "")
|
||||||
|
if not match:
|
||||||
|
return None
|
||||||
|
start, end, total = (int(group) for group in match.groups())
|
||||||
|
if start < 0 or end < start or total <= end:
|
||||||
|
return None
|
||||||
|
return start, end, total
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_range_response(cls, status: int, resp_headers: dict,
|
||||||
|
body: bytes, start_off: int,
|
||||||
|
end_off: int,
|
||||||
|
total_size: int | None = None) -> str | None:
|
||||||
|
if status != 206:
|
||||||
|
return f"status {status}"
|
||||||
|
parsed = cls._parse_content_range(resp_headers.get("content-range", ""))
|
||||||
|
if not parsed:
|
||||||
|
return "missing/invalid Content-Range"
|
||||||
|
got_start, got_end, got_total = parsed
|
||||||
|
if got_start != start_off or got_end != end_off:
|
||||||
|
return f"Content-Range mismatch {got_start}-{got_end}"
|
||||||
|
if total_size is not None and got_total != total_size:
|
||||||
|
return f"Content-Range total mismatch {got_total}/{total_size}"
|
||||||
|
expected = end_off - start_off + 1
|
||||||
|
if len(body) != expected:
|
||||||
|
return f"short chunk {len(body)}/{expected} B"
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _spool_write(file_obj, offset: int, data: bytes) -> None:
|
||||||
|
file_obj.seek(offset)
|
||||||
|
file_obj.write(data)
|
||||||
|
file_obj.flush()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _spool_read(file_obj, offset: int, size: int) -> bytes:
|
||||||
|
file_obj.seek(offset)
|
||||||
|
return file_obj.read(size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_bytes_human(num_bytes: int) -> str:
|
||||||
|
value = float(max(0, num_bytes))
|
||||||
|
units = ("B", "KiB", "MiB", "GiB", "TiB")
|
||||||
|
unit = units[0]
|
||||||
|
for unit in units:
|
||||||
|
if value < 1024.0 or unit == units[-1]:
|
||||||
|
break
|
||||||
|
value /= 1024.0
|
||||||
|
if unit == "B":
|
||||||
|
return f"{int(value)} {unit}"
|
||||||
|
return f"{value:.1f} {unit}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _format_elapsed_short(seconds: float) -> str:
|
||||||
|
total = max(0, int(seconds))
|
||||||
|
minutes, secs = divmod(total, 60)
|
||||||
|
hours, minutes = divmod(minutes, 60)
|
||||||
|
if hours:
|
||||||
|
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
|
||||||
|
return f"{minutes:02d}:{secs:02d}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _render_progress_bar(done: int, total: int, width: int = 34) -> str:
|
||||||
|
if total <= 0:
|
||||||
|
return "[" + ("-" * width) + "]"
|
||||||
|
ratio = max(0.0, min(1.0, done / total))
|
||||||
|
filled = min(width, int(round(ratio * width)))
|
||||||
|
return "[" + ("#" * filled) + ("-" * (width - filled)) + "]"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _progress_line(cls, *, elapsed: float, done: int, total: int,
|
||||||
|
speed_bytes_per_sec: float) -> str:
|
||||||
|
return (
|
||||||
|
f"[{cls._format_elapsed_short(elapsed)}] "
|
||||||
|
f"{cls._render_progress_bar(done, total)} "
|
||||||
|
f"{cls._format_bytes_human(done)} / {cls._format_bytes_human(total)} "
|
||||||
|
f"({cls._format_bytes_human(int(speed_bytes_per_sec))}/s)"
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _relay_payload_h1(self, payload: dict) -> bytes:
|
||||||
|
attempts = self._retry_attempts_for_payload(payload)
|
||||||
|
async with self._semaphore:
|
||||||
|
for attempt in range(attempts):
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(
|
||||||
|
self._relay_single(payload), timeout=self._relay_timeout,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
if attempt < attempts - 1:
|
||||||
|
log.debug(
|
||||||
|
"H1 relay attempt %d failed (%s: %s), retrying",
|
||||||
|
attempt + 1, type(exc).__name__, exc,
|
||||||
|
)
|
||||||
|
await self._flush_pool()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _range_probe(self, url: str, headers: dict, start_off: int,
|
||||||
|
end_off: int, *, max_tries: int = 3) -> bytes:
|
||||||
|
probe_headers = dict(headers) if headers else {}
|
||||||
|
probe_headers["Range"] = f"bytes={start_off}-{end_off}"
|
||||||
|
probe_payload = self._build_payload("GET", url, probe_headers, b"")
|
||||||
|
last_raw = b""
|
||||||
|
last_status = 0
|
||||||
|
for attempt in range(max_tries):
|
||||||
|
try:
|
||||||
|
last_raw = await self._relay_payload_h1(probe_payload)
|
||||||
|
except Exception as exc:
|
||||||
|
if attempt == max_tries - 1:
|
||||||
|
raise
|
||||||
|
log.warning(
|
||||||
|
"Initial range probe %d-%d retry %d/%d failed: %r",
|
||||||
|
start_off, end_off, attempt + 1, max_tries, exc,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.3 * (attempt + 1))
|
||||||
|
continue
|
||||||
|
|
||||||
|
last_status, _, _ = self._split_raw_response(last_raw)
|
||||||
|
if last_status == 206 or last_status < 500:
|
||||||
|
return last_raw
|
||||||
|
if attempt < max_tries - 1:
|
||||||
|
log.warning(
|
||||||
|
"Initial range probe %d-%d retry %d/%d: status %d",
|
||||||
|
start_off, end_off, attempt + 1, max_tries, last_status,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(0.3 * (attempt + 1))
|
||||||
|
return last_raw
|
||||||
|
|
||||||
# ── Per-host stats ────────────────────────────────────────────
|
# ── Per-host stats ────────────────────────────────────────────
|
||||||
|
|
||||||
def _record_site(self, url: str, bytes_: int, latency_ns: int,
|
def _record_site(self, url: str, bytes_: int, latency_ns: int,
|
||||||
@@ -521,12 +778,17 @@ class DomainFronter:
|
|||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
"""Cancel background tasks and close all pooled / H2 connections."""
|
"""Cancel background tasks and close all pooled / H2 connections."""
|
||||||
for task in list(self._bg_tasks):
|
tasks = list(self._bg_tasks)
|
||||||
|
for task in tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
if self._bg_tasks:
|
if tasks:
|
||||||
self._spawn(self._prewarm_script())
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
if self._keepalive_task is None or self._keepalive_task.done():
|
self._bg_tasks.clear()
|
||||||
self._keepalive_task = self._spawn
|
|
||||||
|
self._warm_task = None
|
||||||
|
self._maintenance_task = None
|
||||||
|
self._stats_task = None
|
||||||
|
self._keepalive_task = None
|
||||||
|
|
||||||
await self._flush_pool()
|
await self._flush_pool()
|
||||||
|
|
||||||
@@ -538,18 +800,25 @@ class DomainFronter:
|
|||||||
|
|
||||||
async def _h2_connect(self):
|
async def _h2_connect(self):
|
||||||
"""Connect the HTTP/2 transport in background."""
|
"""Connect the HTTP/2 transport in background."""
|
||||||
|
if self._h2 is None:
|
||||||
|
return
|
||||||
|
if time.time() < self._h2_disabled_until:
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
await self._h2.ensure_connected()
|
await self._h2.ensure_connected()
|
||||||
|
self._record_h2_success()
|
||||||
log.info("H2 multiplexing active — one conn handles all requests")
|
log.info("H2 multiplexing active — one conn handles all requests")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self._record_h2_failure(e)
|
||||||
log.warning("H2 connect failed (%s), using H1 pool fallback", e)
|
log.warning("H2 connect failed (%s), using H1 pool fallback", e)
|
||||||
|
|
||||||
async def _h2_connect_and_warm(self):
|
async def _h2_connect_and_warm(self):
|
||||||
"""Connect H2, pre-warm the Apps Script container, start keepalive."""
|
"""Connect H2, pre-warm the Apps Script container, start keepalive."""
|
||||||
await self._h2_connect()
|
await self._h2_connect()
|
||||||
if self._h2 and self._h2.is_connected:
|
if self._h2_available():
|
||||||
asyncio.create_task(self._prewarm_script())
|
self._spawn(self._prewarm_script())
|
||||||
asyncio.create_task(self._keepalive_loop())
|
if self._keepalive_task is None or self._keepalive_task.done():
|
||||||
|
self._keepalive_task = self._spawn(self._keepalive_loop())
|
||||||
|
|
||||||
async def _prewarm_script(self):
|
async def _prewarm_script(self):
|
||||||
"""Pre-warm Apps Script and detect /dev fast path (no redirect)."""
|
"""Pre-warm Apps Script and detect /dev fast path (no redirect)."""
|
||||||
@@ -602,10 +871,12 @@ class DomainFronter:
|
|||||||
try:
|
try:
|
||||||
await asyncio.sleep(240) # 4 minutes — saves ~90 quota hits/day vs 180s
|
await asyncio.sleep(240) # 4 minutes — saves ~90 quota hits/day vs 180s
|
||||||
# Google's container timeout is ~5 min idle
|
# Google's container timeout is ~5 min idle
|
||||||
if not self._h2 or not self._h2.is_connected:
|
if not self._h2_available():
|
||||||
try:
|
try:
|
||||||
await self._h2.reconnect()
|
await self._h2.reconnect()
|
||||||
except Exception:
|
self._record_h2_success()
|
||||||
|
except Exception as exc:
|
||||||
|
self._record_h2_failure(exc)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# H2 PING to keep connection alive
|
# H2 PING to keep connection alive
|
||||||
@@ -681,7 +952,9 @@ class DomainFronter:
|
|||||||
has_range = True
|
has_range = True
|
||||||
break
|
break
|
||||||
if method == "GET" and not body and not has_range:
|
if method == "GET" and not body and not has_range:
|
||||||
result = await self._coalesced_submit(url, payload)
|
result = await self._coalesced_submit(
|
||||||
|
self._coalesce_key(url, headers), payload,
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
result = await self._batch_submit(payload)
|
result = await self._batch_submit(payload)
|
||||||
@@ -693,7 +966,7 @@ class DomainFronter:
|
|||||||
latency_ns = int((time.perf_counter() - t0) * 1e9)
|
latency_ns = int((time.perf_counter() - t0) * 1e9)
|
||||||
self._record_site(url, len(result), latency_ns, errored)
|
self._record_site(url, len(result), latency_ns, errored)
|
||||||
|
|
||||||
async def _coalesced_submit(self, url: str, payload: dict) -> bytes:
|
async def _coalesced_submit(self, key: str, payload: dict) -> bytes:
|
||||||
"""Dedup concurrent requests for the same URL (no Range header).
|
"""Dedup concurrent requests for the same URL (no Range header).
|
||||||
|
|
||||||
Uses `_batch_lock` to atomically check-and-append, preventing a
|
Uses `_batch_lock` to atomically check-and-append, preventing a
|
||||||
@@ -702,14 +975,14 @@ class DomainFronter:
|
|||||||
"""
|
"""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
async with self._batch_lock:
|
async with self._batch_lock:
|
||||||
waiters = self._coalesce.get(url)
|
waiters = self._coalesce.get(key)
|
||||||
if waiters is not None:
|
if waiters is not None:
|
||||||
future = loop.create_future()
|
future = loop.create_future()
|
||||||
waiters.append(future)
|
waiters.append(future)
|
||||||
log.debug("Coalesced request: %s", url[:60])
|
log.debug("Coalesced request: %s", key.split("\n", 1)[0][:60])
|
||||||
waiting = True
|
waiting = True
|
||||||
else:
|
else:
|
||||||
self._coalesce[url] = []
|
self._coalesce[key] = []
|
||||||
waiting = False
|
waiting = False
|
||||||
|
|
||||||
if waiting:
|
if waiting:
|
||||||
@@ -719,14 +992,14 @@ class DomainFronter:
|
|||||||
result = await self._batch_submit(payload)
|
result = await self._batch_submit(payload)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
async with self._batch_lock:
|
async with self._batch_lock:
|
||||||
waiters = self._coalesce.pop(url, [])
|
waiters = self._coalesce.pop(key, [])
|
||||||
for f in waiters:
|
for f in waiters:
|
||||||
if not f.done():
|
if not f.done():
|
||||||
f.set_exception(e)
|
f.set_exception(e)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async with self._batch_lock:
|
async with self._batch_lock:
|
||||||
waiters = self._coalesce.pop(url, [])
|
waiters = self._coalesce.pop(key, [])
|
||||||
for f in waiters:
|
for f in waiters:
|
||||||
if not f.done():
|
if not f.done():
|
||||||
f.set_result(result)
|
f.set_result(result)
|
||||||
@@ -734,8 +1007,10 @@ class DomainFronter:
|
|||||||
|
|
||||||
async def relay_parallel(self, method: str, url: str,
|
async def relay_parallel(self, method: str, url: str,
|
||||||
headers: dict, body: bytes = b"",
|
headers: dict, body: bytes = b"",
|
||||||
chunk_size: int = 256 * 1024,
|
chunk_size: int = 512 * 1024,
|
||||||
max_parallel: int = 16) -> bytes:
|
max_parallel: int = 8,
|
||||||
|
max_chunks: int = 256,
|
||||||
|
min_size: int = 0) -> bytes:
|
||||||
"""Relay with parallel range acceleration for large downloads.
|
"""Relay with parallel range acceleration for large downloads.
|
||||||
|
|
||||||
Strategy:
|
Strategy:
|
||||||
@@ -747,17 +1022,15 @@ class DomainFronter:
|
|||||||
|
|
||||||
Since each Apps Script call takes ~2s regardless of payload size,
|
Since each Apps Script call takes ~2s regardless of payload size,
|
||||||
we use:
|
we use:
|
||||||
- 256 KB chunks (safe under Apps Script response limit)
|
- 512 KB chunks (fewer relay calls, lower quota pressure)
|
||||||
- Up to 16 chunks in flight at once via H2 multiplexing
|
- Up to 8 chunks in flight at once via H2 multiplexing
|
||||||
- Aggregate throughput of ~2 MB per round-trip (~2-3s)
|
- Aggregate throughput of ~2 MB per round-trip (~2-3s)
|
||||||
"""
|
"""
|
||||||
if method != "GET" or body:
|
if method != "GET" or body:
|
||||||
return await self.relay(method, url, headers, body)
|
return await self.relay(method, url, headers, body)
|
||||||
|
|
||||||
# Probe: first chunk with Range header
|
# Probe: first chunk with Range header
|
||||||
range_headers = dict(headers) if headers else {}
|
first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
|
||||||
range_headers["Range"] = f"bytes=0-{chunk_size - 1}"
|
|
||||||
first_resp = await self.relay("GET", url, range_headers, b"")
|
|
||||||
|
|
||||||
status, resp_hdrs, resp_body = self._split_raw_response(first_resp)
|
status, resp_hdrs, resp_body = self._split_raw_response(first_resp)
|
||||||
|
|
||||||
@@ -768,13 +1041,40 @@ class DomainFronter:
|
|||||||
return first_resp
|
return first_resp
|
||||||
|
|
||||||
# Parse total size from Content-Range: "bytes 0-262143/1048576"
|
# Parse total size from Content-Range: "bytes 0-262143/1048576"
|
||||||
content_range = resp_hdrs.get("content-range", "")
|
parsed_range = self._parse_content_range(resp_hdrs.get("content-range", ""))
|
||||||
m = re.search(r"/(\d+)", content_range)
|
if not parsed_range:
|
||||||
if not m:
|
|
||||||
# Can't parse — downgrade to 200 so the client (which sent a
|
# Can't parse — downgrade to 200 so the client (which sent a
|
||||||
# plain GET) doesn't get confused by 206 + Content-Range.
|
# plain GET) doesn't get confused by 206 + Content-Range.
|
||||||
return self._rewrite_206_to_200(first_resp)
|
return self._rewrite_206_to_200(first_resp)
|
||||||
total_size = int(m.group(1))
|
first_start, first_end, total_size = parsed_range
|
||||||
|
first_err = self._validate_range_response(
|
||||||
|
status, resp_hdrs, resp_body, first_start, first_end, total_size,
|
||||||
|
)
|
||||||
|
if first_start != 0 or first_err:
|
||||||
|
return self._rewrite_206_to_200(first_resp)
|
||||||
|
if total_size > self._max_response_body_bytes:
|
||||||
|
return self._error_response(
|
||||||
|
502,
|
||||||
|
"Relay response exceeds cap "
|
||||||
|
f"({self._max_response_body_bytes} bytes). "
|
||||||
|
"Increase max_response_body_bytes if your system has enough RAM.",
|
||||||
|
)
|
||||||
|
if min_size > 0 and total_size < min_size:
|
||||||
|
return self._rewrite_206_to_200(first_resp)
|
||||||
|
if max_chunks > 0:
|
||||||
|
required_chunk_size = max(
|
||||||
|
chunk_size,
|
||||||
|
(total_size + max_chunks - 1) // max_chunks,
|
||||||
|
)
|
||||||
|
if required_chunk_size != chunk_size:
|
||||||
|
log.info(
|
||||||
|
"Parallel download tuning: chunk size raised from %d KB to %d KB "
|
||||||
|
"to keep request count under %d",
|
||||||
|
chunk_size // 1024,
|
||||||
|
required_chunk_size // 1024,
|
||||||
|
max_chunks,
|
||||||
|
)
|
||||||
|
chunk_size = required_chunk_size
|
||||||
|
|
||||||
# Small file: probe already fetched it all. MUST rewrite to 200
|
# Small file: probe already fetched it all. MUST rewrite to 200
|
||||||
# because the client never sent a Range header — a stray 206 here
|
# because the client never sent a Range header — a stray 206 here
|
||||||
@@ -795,22 +1095,54 @@ class DomainFronter:
|
|||||||
|
|
||||||
# Concurrency-limited parallel fetch
|
# Concurrency-limited parallel fetch
|
||||||
sem = asyncio.Semaphore(max_parallel)
|
sem = asyncio.Semaphore(max_parallel)
|
||||||
|
progress_lock = asyncio.Lock()
|
||||||
|
completed_chunks = 1 # first range probe already succeeded
|
||||||
|
completed_bytes = len(resp_body)
|
||||||
|
last_progress_log = time.perf_counter()
|
||||||
|
total_chunks = len(ranges) + 1
|
||||||
|
total_bytes = total_size
|
||||||
|
|
||||||
async def fetch_range(s, e, max_tries: int = 3):
|
async def fetch_range(s, e, max_tries: int = 3):
|
||||||
|
nonlocal completed_chunks, completed_bytes, last_progress_log
|
||||||
async with sem:
|
async with sem:
|
||||||
rh_base = dict(headers) if headers else {}
|
rh_base = dict(headers) if headers else {}
|
||||||
rh_base["Range"] = f"bytes={s}-{e}"
|
rh_base["Range"] = f"bytes={s}-{e}"
|
||||||
|
payload = self._build_payload("GET", url, rh_base, b"")
|
||||||
expected = e - s + 1
|
expected = e - s + 1
|
||||||
last_err = None
|
last_err = None
|
||||||
for attempt in range(max_tries):
|
for attempt in range(max_tries):
|
||||||
try:
|
try:
|
||||||
raw = await self.relay("GET", url, rh_base, b"")
|
raw = await self._relay_payload_h1(payload)
|
||||||
_, _, chunk_body = self._split_raw_response(raw)
|
chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw)
|
||||||
if len(chunk_body) == expected:
|
err = self._validate_range_response(
|
||||||
return chunk_body
|
chunk_status, chunk_headers, chunk_body,
|
||||||
last_err = (
|
s, e, total_size,
|
||||||
f"short chunk {len(chunk_body)}/{expected} B"
|
|
||||||
)
|
)
|
||||||
|
if err is None:
|
||||||
|
now = time.perf_counter()
|
||||||
|
async with progress_lock:
|
||||||
|
completed_chunks += 1
|
||||||
|
completed_bytes += len(chunk_body)
|
||||||
|
should_log = (
|
||||||
|
completed_chunks == total_chunks
|
||||||
|
or (now - last_progress_log) >= 5.0
|
||||||
|
)
|
||||||
|
if should_log:
|
||||||
|
elapsed = max(0.001, now - t0)
|
||||||
|
speed_bps = completed_bytes / elapsed
|
||||||
|
log.info(
|
||||||
|
"Parallel download progress: %s [%d/%d chunks]",
|
||||||
|
self._progress_line(
|
||||||
|
elapsed=elapsed,
|
||||||
|
done=completed_bytes,
|
||||||
|
total=total_bytes,
|
||||||
|
speed_bytes_per_sec=speed_bps,
|
||||||
|
),
|
||||||
|
completed_chunks, total_chunks,
|
||||||
|
)
|
||||||
|
last_progress_log = now
|
||||||
|
return chunk_body
|
||||||
|
last_err = err
|
||||||
except Exception as e_:
|
except Exception as e_:
|
||||||
last_err = repr(e_)
|
last_err = repr(e_)
|
||||||
log.warning("Range %d-%d retry %d/%d: %s",
|
log.warning("Range %d-%d retry %d/%d: %s",
|
||||||
@@ -837,8 +1169,15 @@ class DomainFronter:
|
|||||||
|
|
||||||
full_body = b"".join(parts)
|
full_body = b"".join(parts)
|
||||||
kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0
|
kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0
|
||||||
log.info("Parallel download complete: %d B in %.2fs = %.1f KB/s",
|
log.info(
|
||||||
len(full_body), elapsed, kbs)
|
"Parallel download complete: %s",
|
||||||
|
self._progress_line(
|
||||||
|
elapsed=elapsed,
|
||||||
|
done=len(full_body),
|
||||||
|
total=len(full_body),
|
||||||
|
speed_bytes_per_sec=kbs * 1024,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
# Return as 200 OK (client sent a normal GET)
|
# Return as 200 OK (client sent a normal GET)
|
||||||
result = f"HTTP/1.1 200 OK\r\n"
|
result = f"HTTP/1.1 200 OK\r\n"
|
||||||
@@ -851,6 +1190,219 @@ class DomainFronter:
|
|||||||
result += "\r\n"
|
result += "\r\n"
|
||||||
return result.encode() + full_body
|
return result.encode() + full_body
|
||||||
|
|
||||||
|
async def stream_parallel_download(self, url: str, headers: dict,
|
||||||
|
writer,
|
||||||
|
*,
|
||||||
|
chunk_size: int = 512 * 1024,
|
||||||
|
max_parallel: int = 8,
|
||||||
|
max_chunks: int = 256,
|
||||||
|
min_size: int = 0) -> bool:
|
||||||
|
"""Stream a large range-capable download to the client incrementally.
|
||||||
|
|
||||||
|
Returns False when the target should fall back to the normal relay
|
||||||
|
path (for example no range support or the file is too small).
|
||||||
|
Returns True once this method has taken ownership of the client
|
||||||
|
response, even if the stream later aborts.
|
||||||
|
"""
|
||||||
|
first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
|
||||||
|
|
||||||
|
status, resp_hdrs, resp_body = self._split_raw_response(first_resp)
|
||||||
|
if status != 206:
|
||||||
|
log.info(
|
||||||
|
"Streaming download fallback: initial probe returned %s for %s",
|
||||||
|
status, url[:80],
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
|
||||||
|
parsed_range = self._parse_content_range(resp_hdrs.get("content-range", ""))
|
||||||
|
if not parsed_range:
|
||||||
|
log.info(
|
||||||
|
"Streaming download fallback: missing/invalid Content-Range for %s",
|
||||||
|
url[:80],
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
first_start, first_end, total_size = parsed_range
|
||||||
|
first_err = self._validate_range_response(
|
||||||
|
status, resp_hdrs, resp_body, first_start, first_end, total_size,
|
||||||
|
)
|
||||||
|
if first_start != 0 or first_err:
|
||||||
|
log.info(
|
||||||
|
"Streaming download fallback: invalid first range (%s) for %s",
|
||||||
|
first_err or f"start={first_start}",
|
||||||
|
url[:80],
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if min_size > 0 and total_size < min_size:
|
||||||
|
log.info(
|
||||||
|
"Streaming download fallback: file too small (%d < %d) for %s",
|
||||||
|
total_size, min_size, url[:80],
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if max_chunks > 0:
|
||||||
|
required_chunk_size = max(
|
||||||
|
chunk_size,
|
||||||
|
(total_size + max_chunks - 1) // max_chunks,
|
||||||
|
)
|
||||||
|
if required_chunk_size != chunk_size:
|
||||||
|
log.info(
|
||||||
|
"Parallel download tuning: chunk size raised from %d KB to %d KB "
|
||||||
|
"to keep request count under %d",
|
||||||
|
chunk_size // 1024,
|
||||||
|
required_chunk_size // 1024,
|
||||||
|
max_chunks,
|
||||||
|
)
|
||||||
|
chunk_size = required_chunk_size
|
||||||
|
|
||||||
|
if total_size <= chunk_size or len(resp_body) >= total_size:
|
||||||
|
writer.write(self._render_streaming_headers(resp_hdrs, total_size))
|
||||||
|
writer.write(resp_body)
|
||||||
|
await writer.drain()
|
||||||
|
return True
|
||||||
|
|
||||||
|
ranges = []
|
||||||
|
start = len(resp_body)
|
||||||
|
while start < total_size:
|
||||||
|
end = min(start + chunk_size - 1, total_size - 1)
|
||||||
|
ranges.append((start, end))
|
||||||
|
start = end + 1
|
||||||
|
|
||||||
|
log.info("Parallel streaming download: %d bytes, %d chunks of %d KB",
|
||||||
|
total_size, len(ranges) + 1, chunk_size // 1024)
|
||||||
|
|
||||||
|
temp_file = tempfile.TemporaryFile(prefix="mhrvpn_dl_")
|
||||||
|
file_lock = asyncio.Lock()
|
||||||
|
sem = asyncio.Semaphore(max_parallel)
|
||||||
|
cancel_event = asyncio.Event()
|
||||||
|
tasks: list[asyncio.Task] = []
|
||||||
|
ready = [asyncio.Event() for _ in ranges]
|
||||||
|
errors: list[Exception | None] = [None for _ in ranges]
|
||||||
|
delivered_chunks = 1
|
||||||
|
delivered_bytes = len(resp_body)
|
||||||
|
total_chunks = len(ranges) + 1
|
||||||
|
last_progress_log = time.perf_counter()
|
||||||
|
t0 = time.perf_counter()
|
||||||
|
|
||||||
|
async def _write_progress(force: bool = False) -> None:
|
||||||
|
nonlocal last_progress_log
|
||||||
|
now = time.perf_counter()
|
||||||
|
if not force and (now - last_progress_log) < 5.0:
|
||||||
|
return
|
||||||
|
elapsed = max(0.001, now - t0)
|
||||||
|
speed_bps = delivered_bytes / elapsed
|
||||||
|
log.info(
|
||||||
|
"Parallel download progress: %s [%d/%d chunks]",
|
||||||
|
self._progress_line(
|
||||||
|
elapsed=elapsed,
|
||||||
|
done=delivered_bytes,
|
||||||
|
total=total_size,
|
||||||
|
speed_bytes_per_sec=speed_bps,
|
||||||
|
),
|
||||||
|
delivered_chunks, total_chunks,
|
||||||
|
)
|
||||||
|
last_progress_log = now
|
||||||
|
|
||||||
|
async def fetch_range(index: int, start_off: int, end_off: int,
|
||||||
|
max_tries: int = 3) -> None:
|
||||||
|
async with sem:
|
||||||
|
base_headers = dict(headers) if headers else {}
|
||||||
|
base_headers["Range"] = f"bytes={start_off}-{end_off}"
|
||||||
|
payload = self._build_payload("GET", url, base_headers, b"")
|
||||||
|
expected = end_off - start_off + 1
|
||||||
|
last_err = "unknown"
|
||||||
|
try:
|
||||||
|
for attempt in range(max_tries):
|
||||||
|
if cancel_event.is_set():
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
raw = await self._relay_payload_h1(payload)
|
||||||
|
chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw)
|
||||||
|
err = self._validate_range_response(
|
||||||
|
chunk_status, chunk_headers, chunk_body,
|
||||||
|
start_off, end_off, total_size,
|
||||||
|
)
|
||||||
|
if err is None:
|
||||||
|
async with file_lock:
|
||||||
|
await asyncio.to_thread(
|
||||||
|
self._spool_write, temp_file, start_off, chunk_body,
|
||||||
|
)
|
||||||
|
ready[index].set()
|
||||||
|
return
|
||||||
|
last_err = err
|
||||||
|
except Exception as exc:
|
||||||
|
last_err = repr(exc)
|
||||||
|
if cancel_event.is_set():
|
||||||
|
return
|
||||||
|
log.warning("Range %d-%d retry %d/%d: %s",
|
||||||
|
start_off, end_off, attempt + 1, max_tries, last_err)
|
||||||
|
await asyncio.sleep(0.3 * (attempt + 1))
|
||||||
|
errors[index] = RuntimeError(
|
||||||
|
f"chunk {start_off}-{end_off} failed after {max_tries} tries: {last_err}"
|
||||||
|
)
|
||||||
|
ready[index].set()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
|
||||||
|
try:
|
||||||
|
writer.write(self._render_streaming_headers(resp_hdrs, total_size))
|
||||||
|
writer.write(resp_body)
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
for index, (start_off, end_off) in enumerate(ranges):
|
||||||
|
tasks.append(asyncio.create_task(fetch_range(index, start_off, end_off)))
|
||||||
|
|
||||||
|
for index, (start_off, end_off) in enumerate(ranges):
|
||||||
|
await ready[index].wait()
|
||||||
|
if errors[index] is not None:
|
||||||
|
raise errors[index]
|
||||||
|
expected = end_off - start_off + 1
|
||||||
|
async with file_lock:
|
||||||
|
chunk = await asyncio.to_thread(
|
||||||
|
self._spool_read, temp_file, start_off, expected,
|
||||||
|
)
|
||||||
|
if len(chunk) != expected:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"spooled chunk {start_off}-{end_off} was truncated "
|
||||||
|
f"({len(chunk)}/{expected} B)"
|
||||||
|
)
|
||||||
|
writer.write(chunk)
|
||||||
|
await writer.drain()
|
||||||
|
delivered_chunks += 1
|
||||||
|
delivered_bytes += len(chunk)
|
||||||
|
await _write_progress(force=(index == len(ranges) - 1))
|
||||||
|
|
||||||
|
elapsed = max(0.001, time.perf_counter() - t0)
|
||||||
|
log.info(
|
||||||
|
"Parallel streaming download complete: %s",
|
||||||
|
self._progress_line(
|
||||||
|
elapsed=elapsed,
|
||||||
|
done=total_size,
|
||||||
|
total=total_size,
|
||||||
|
speed_bytes_per_sec=total_size / elapsed,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
except (ConnectionError, BrokenPipeError, TimeoutError) as exc:
|
||||||
|
log.info("Parallel download cancelled by client: %s", exc)
|
||||||
|
cancel_event.set()
|
||||||
|
return True
|
||||||
|
except Exception as exc:
|
||||||
|
self._mark_stream_download_failure(url, str(exc))
|
||||||
|
log.error("Parallel streaming download failed (%s): %s", url[:60], exc)
|
||||||
|
cancel_event.set()
|
||||||
|
try:
|
||||||
|
if not writer.is_closing():
|
||||||
|
writer.close()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
cancel_event.set()
|
||||||
|
for task in tasks:
|
||||||
|
task.cancel()
|
||||||
|
if tasks:
|
||||||
|
await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
temp_file.close()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _rewrite_206_to_200(raw: bytes) -> bytes:
|
def _rewrite_206_to_200(raw: bytes) -> bytes:
|
||||||
"""Rewrite a 206 Partial Content response to 200 OK.
|
"""Rewrite a 206 Partial Content response to 200 OK.
|
||||||
@@ -1039,33 +1591,43 @@ class DomainFronter:
|
|||||||
|
|
||||||
async def _relay_with_retry(self, payload: dict) -> bytes:
|
async def _relay_with_retry(self, payload: dict) -> bytes:
|
||||||
"""Single relay with one retry on failure. Uses H2 if available."""
|
"""Single relay with one retry on failure. Uses H2 if available."""
|
||||||
|
attempts = self._retry_attempts_for_payload(payload)
|
||||||
# Fan-out: race N Apps Script instances when enabled and H2 is up.
|
# Fan-out: race N Apps Script instances when enabled and H2 is up.
|
||||||
# Cuts tail latency when one container is slow/cold. Only kicks in
|
# Cuts tail latency when one container is slow/cold. Only kicks in
|
||||||
# if multiple script IDs are configured and the H2 transport is live.
|
# if multiple script IDs are configured and the H2 transport is live.
|
||||||
if (self._parallel_relay > 1
|
if (attempts > 1
|
||||||
|
and self._parallel_relay > 1
|
||||||
and len(self._script_ids) > 1
|
and len(self._script_ids) > 1
|
||||||
and self._h2 and self._h2.is_connected):
|
and self._h2_available()):
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._relay_fanout(payload), timeout=RELAY_TIMEOUT,
|
self._relay_fanout(payload), timeout=self._relay_timeout,
|
||||||
)
|
)
|
||||||
|
self._record_h2_success()
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self._record_h2_failure(e)
|
||||||
log.debug("Fan-out relay failed (%s), falling back", e)
|
log.debug("Fan-out relay failed (%s), falling back", e)
|
||||||
# fall through to single-path logic below
|
# fall through to single-path logic below
|
||||||
|
|
||||||
# Try HTTP/2 first — much faster (multiplexed, no pool checkout)
|
# Try HTTP/2 first — much faster (multiplexed, no pool checkout)
|
||||||
if self._h2 and self._h2.is_connected:
|
if self._h2_available():
|
||||||
for attempt in range(2):
|
for attempt in range(attempts):
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(
|
result = await asyncio.wait_for(
|
||||||
self._relay_single_h2(payload), timeout=RELAY_TIMEOUT
|
self._relay_single_h2(payload), timeout=self._relay_timeout
|
||||||
)
|
)
|
||||||
|
self._record_h2_success()
|
||||||
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt == 0:
|
self._record_h2_failure(e)
|
||||||
|
if attempt < attempts - 1:
|
||||||
log.debug("H2 relay failed (%s), reconnecting", e)
|
log.debug("H2 relay failed (%s), reconnecting", e)
|
||||||
try:
|
try:
|
||||||
await self._h2.reconnect()
|
await self._h2.reconnect()
|
||||||
except Exception:
|
self._record_h2_success()
|
||||||
|
except Exception as reconnect_exc:
|
||||||
|
self._record_h2_failure(reconnect_exc)
|
||||||
log.warning("H2 reconnect failed, falling back to H1")
|
log.warning("H2 reconnect failed, falling back to H1")
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
@@ -1073,14 +1635,15 @@ class DomainFronter:
|
|||||||
|
|
||||||
# HTTP/1.1 fallback (pool-based)
|
# HTTP/1.1 fallback (pool-based)
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
for attempt in range(2):
|
for attempt in range(attempts):
|
||||||
try:
|
try:
|
||||||
return await asyncio.wait_for(
|
return await asyncio.wait_for(
|
||||||
self._relay_single(payload), timeout=RELAY_TIMEOUT
|
self._relay_single(payload), timeout=self._relay_timeout
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt == 0:
|
if attempt < attempts - 1:
|
||||||
log.debug("Relay attempt 1 failed (%s: %s), retrying",
|
log.debug("Relay attempt %d failed (%s: %s), retrying",
|
||||||
|
attempt + 1,
|
||||||
type(e).__name__, e)
|
type(e).__name__, e)
|
||||||
await self._flush_pool()
|
await self._flush_pool()
|
||||||
else:
|
else:
|
||||||
@@ -1248,7 +1811,7 @@ class DomainFronter:
|
|||||||
path = self._exec_path(payloads[0].get("u") if payloads else None)
|
path = self._exec_path(payloads[0].get("u") if payloads else None)
|
||||||
|
|
||||||
# Try HTTP/2 first
|
# Try HTTP/2 first
|
||||||
if self._h2 and self._h2.is_connected:
|
if self._h2_available():
|
||||||
try:
|
try:
|
||||||
status, headers, body = await asyncio.wait_for(
|
status, headers, body = await asyncio.wait_for(
|
||||||
self._h2.request(
|
self._h2.request(
|
||||||
@@ -1258,8 +1821,10 @@ class DomainFronter:
|
|||||||
),
|
),
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
|
self._record_h2_success()
|
||||||
return self._parse_batch_body(body, payloads)
|
return self._parse_batch_body(body, payloads)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
self._record_h2_failure(e)
|
||||||
log.debug("H2 batch failed (%s), falling back to H1", e)
|
log.debug("H2 batch failed (%s), falling back to H1", e)
|
||||||
|
|
||||||
# HTTP/1.1 fallback
|
# HTTP/1.1 fallback
|
||||||
@@ -1383,7 +1948,13 @@ class DomainFronter:
|
|||||||
if "chunked" in transfer_encoding:
|
if "chunked" in transfer_encoding:
|
||||||
body = await self._read_chunked(reader, body)
|
body = await self._read_chunked(reader, body)
|
||||||
elif content_length:
|
elif content_length:
|
||||||
remaining = int(content_length) - len(body)
|
total = int(content_length)
|
||||||
|
if total > self._max_response_body_bytes:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Relay response exceeds configured size cap "
|
||||||
|
f"({total} > {self._max_response_body_bytes} bytes)"
|
||||||
|
)
|
||||||
|
remaining = total - len(body)
|
||||||
while remaining > 0:
|
while remaining > 0:
|
||||||
chunk = await asyncio.wait_for(
|
chunk = await asyncio.wait_for(
|
||||||
reader.read(min(remaining, 65536)), timeout=20
|
reader.read(min(remaining, 65536)), timeout=20
|
||||||
@@ -1391,6 +1962,10 @@ class DomainFronter:
|
|||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
body += chunk
|
body += chunk
|
||||||
|
if len(body) > self._max_response_body_bytes:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Relay response exceeded configured size cap while reading body"
|
||||||
|
)
|
||||||
remaining -= len(chunk)
|
remaining -= len(chunk)
|
||||||
else:
|
else:
|
||||||
# No framing — short timeout read (keep-alive safe)
|
# No framing — short timeout read (keep-alive safe)
|
||||||
@@ -1400,6 +1975,10 @@ class DomainFronter:
|
|||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
body += chunk
|
body += chunk
|
||||||
|
if len(body) > self._max_response_body_bytes:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Relay response exceeded configured size cap while streaming"
|
||||||
|
)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -1407,13 +1986,17 @@ class DomainFronter:
|
|||||||
enc = headers.get("content-encoding", "")
|
enc = headers.get("content-encoding", "")
|
||||||
if enc:
|
if enc:
|
||||||
body = codec.decode(body, enc)
|
body = codec.decode(body, enc)
|
||||||
|
if len(body) > self._max_response_body_bytes:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Decoded relay response exceeded configured size cap"
|
||||||
|
)
|
||||||
|
|
||||||
return status, headers, body
|
return status, headers, body
|
||||||
|
|
||||||
async def _read_chunked(self, reader, buf=b""):
|
async def _read_chunked(self, reader, buf=b""):
|
||||||
"""Incrementally read chunked transfer-encoding."""
|
"""Incrementally read chunked transfer-encoding."""
|
||||||
result = b""
|
result = b""
|
||||||
_MAX_BODY = 200 * 1024 * 1024 # 200 MB total body cap
|
max_body = self._max_response_body_bytes
|
||||||
while True:
|
while True:
|
||||||
while b"\r\n" not in buf:
|
while b"\r\n" not in buf:
|
||||||
data = await asyncio.wait_for(reader.read(8192), timeout=20)
|
data = await asyncio.wait_for(reader.read(8192), timeout=20)
|
||||||
@@ -1433,9 +2016,11 @@ class DomainFronter:
|
|||||||
break
|
break
|
||||||
if size == 0:
|
if size == 0:
|
||||||
break
|
break
|
||||||
if size > _MAX_BODY or len(result) + size > _MAX_BODY:
|
if size > max_body or len(result) + size > max_body:
|
||||||
log.warning("Chunked body exceeds %d MB cap — truncating", _MAX_BODY // (1024 * 1024))
|
raise RuntimeError(
|
||||||
break
|
"Chunked relay response exceeded configured size cap "
|
||||||
|
f"({max_body} bytes)"
|
||||||
|
)
|
||||||
|
|
||||||
while len(buf) < size + 2:
|
while len(buf) < size + 2:
|
||||||
data = await asyncio.wait_for(reader.read(65536), timeout=20)
|
data = await asyncio.wait_for(reader.read(65536), timeout=20)
|
||||||
@@ -1479,6 +2064,13 @@ class DomainFronter:
|
|||||||
status = data.get("s", 200)
|
status = data.get("s", 200)
|
||||||
resp_headers = data.get("h", {})
|
resp_headers = data.get("h", {})
|
||||||
resp_body = base64.b64decode(data.get("b", ""))
|
resp_body = base64.b64decode(data.get("b", ""))
|
||||||
|
if len(resp_body) > self._max_response_body_bytes:
|
||||||
|
return self._error_response(
|
||||||
|
502,
|
||||||
|
"Relay response exceeds cap "
|
||||||
|
f"({self._max_response_body_bytes} bytes). "
|
||||||
|
"Increase max_response_body_bytes if your system has enough RAM.",
|
||||||
|
)
|
||||||
|
|
||||||
status_text = {200: "OK", 206: "Partial Content",
|
status_text = {200: "OK", 206: "Partial Content",
|
||||||
301: "Moved", 302: "Found", 304: "Not Modified",
|
301: "Moved", 302: "Found", 304: "Not Modified",
|
||||||
|
|||||||
+31
-16
@@ -80,6 +80,7 @@ class H2Transport:
|
|||||||
self._write_lock = asyncio.Lock()
|
self._write_lock = asyncio.Lock()
|
||||||
self._connect_lock = asyncio.Lock()
|
self._connect_lock = asyncio.Lock()
|
||||||
self._read_task: asyncio.Task | None = None
|
self._read_task: asyncio.Task | None = None
|
||||||
|
self._conn_generation = 0
|
||||||
|
|
||||||
# Per-stream tracking
|
# Per-stream tracking
|
||||||
self._streams: dict[int, _StreamState] = {}
|
self._streams: dict[int, _StreamState] = {}
|
||||||
@@ -174,26 +175,34 @@ class H2Transport:
|
|||||||
await self._flush()
|
await self._flush()
|
||||||
|
|
||||||
self._connected = True
|
self._connected = True
|
||||||
self._read_task = asyncio.create_task(self._reader_loop())
|
self._conn_generation += 1
|
||||||
|
generation = self._conn_generation
|
||||||
|
self._read_task = asyncio.create_task(self._reader_loop(generation))
|
||||||
log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)",
|
log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)",
|
||||||
self.connect_host, sni)
|
self.connect_host, sni)
|
||||||
|
|
||||||
async def reconnect(self):
|
async def reconnect(self):
|
||||||
"""Close current connection and re-establish."""
|
"""Close current connection and re-establish."""
|
||||||
await self._close_internal()
|
async with self._connect_lock:
|
||||||
await self._do_connect()
|
await self._close_internal()
|
||||||
|
await self._do_connect()
|
||||||
|
|
||||||
async def _close_internal(self):
|
async def _close_internal(self):
|
||||||
self._connected = False
|
self._connected = False
|
||||||
if self._read_task:
|
read_task = self._read_task
|
||||||
self._read_task.cancel()
|
self._read_task = None
|
||||||
self._read_task = None
|
if read_task:
|
||||||
|
read_task.cancel()
|
||||||
|
await asyncio.gather(read_task, return_exceptions=True)
|
||||||
if self._writer:
|
if self._writer:
|
||||||
try:
|
try:
|
||||||
self._writer.close()
|
writer = self._writer
|
||||||
|
self._writer = None
|
||||||
|
writer.close()
|
||||||
|
await writer.wait_closed()
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._writer = None
|
self._reader = None
|
||||||
# Wake all pending streams so they can raise
|
# Wake all pending streams so they can raise
|
||||||
for state in self._streams.values():
|
for state in self._streams.values():
|
||||||
state.error = "Connection closed"
|
state.error = "Connection closed"
|
||||||
@@ -327,7 +336,7 @@ class H2Transport:
|
|||||||
|
|
||||||
# ── Background reader ─────────────────────────────────────────
|
# ── Background reader ─────────────────────────────────────────
|
||||||
|
|
||||||
async def _reader_loop(self):
|
async def _reader_loop(self, generation: int):
|
||||||
"""Background: read H2 frames, dispatch events to waiting streams."""
|
"""Background: read H2 frames, dispatch events to waiting streams."""
|
||||||
try:
|
try:
|
||||||
while self._connected:
|
while self._connected:
|
||||||
@@ -352,14 +361,20 @@ class H2Transport:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("H2 reader error: %s", e)
|
if "application data after close notify" in str(e).lower():
|
||||||
|
log.debug("H2 reader closed after close_notify: %s", e)
|
||||||
|
else:
|
||||||
|
log.error("H2 reader error: %s", e)
|
||||||
finally:
|
finally:
|
||||||
self._connected = False
|
if generation != self._conn_generation:
|
||||||
for state in self._streams.values():
|
log.debug("H2 reader loop ended for stale generation %d", generation)
|
||||||
if not state.done.is_set():
|
else:
|
||||||
state.error = "Connection lost"
|
self._connected = False
|
||||||
state.done.set()
|
for state in self._streams.values():
|
||||||
log.info("H2 reader loop ended")
|
if not state.done.is_set():
|
||||||
|
state.error = "Connection lost"
|
||||||
|
state.done.set()
|
||||||
|
log.info("H2 reader loop ended")
|
||||||
|
|
||||||
def _dispatch(self, event):
|
def _dispatch(self, event):
|
||||||
"""Route a single h2 event to its stream."""
|
"""Route a single h2 event to its stream."""
|
||||||
|
|||||||
+199
-114
@@ -65,6 +65,23 @@ def _parse_content_length(header_block: bytes) -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def _has_unsupported_transfer_encoding(header_block: bytes) -> bool:
|
||||||
|
"""True when the request uses Transfer-Encoding, which we don't stream."""
|
||||||
|
for raw_line in header_block.split(b"\r\n"):
|
||||||
|
name, sep, value = raw_line.partition(b":")
|
||||||
|
if not sep:
|
||||||
|
continue
|
||||||
|
if name.strip().lower() != b"transfer-encoding":
|
||||||
|
continue
|
||||||
|
encodings = [
|
||||||
|
token.strip().lower()
|
||||||
|
for token in value.decode(errors="replace").split(",")
|
||||||
|
if token.strip()
|
||||||
|
]
|
||||||
|
return any(token != "identity" for token in encodings)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
class ResponseCache:
|
class ResponseCache:
|
||||||
"""Simple LRU response cache — avoids repeated relay calls."""
|
"""Simple LRU response cache — avoids repeated relay calls."""
|
||||||
|
|
||||||
@@ -147,6 +164,14 @@ class ProxyServer:
|
|||||||
_GOOGLE_DIRECT_ALLOW_EXACT = GOOGLE_DIRECT_ALLOW_EXACT
|
_GOOGLE_DIRECT_ALLOW_EXACT = GOOGLE_DIRECT_ALLOW_EXACT
|
||||||
_GOOGLE_DIRECT_ALLOW_SUFFIXES = GOOGLE_DIRECT_ALLOW_SUFFIXES
|
_GOOGLE_DIRECT_ALLOW_SUFFIXES = GOOGLE_DIRECT_ALLOW_SUFFIXES
|
||||||
_TRACE_HOST_SUFFIXES = TRACE_HOST_SUFFIXES
|
_TRACE_HOST_SUFFIXES = TRACE_HOST_SUFFIXES
|
||||||
|
_DOWNLOAD_DEFAULT_EXTS = tuple(sorted(LARGE_FILE_EXTS))
|
||||||
|
_DOWNLOAD_ACCEPT_MARKERS = (
|
||||||
|
"application/octet-stream",
|
||||||
|
"application/zip",
|
||||||
|
"application/x-bittorrent",
|
||||||
|
"video/",
|
||||||
|
"audio/",
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
self.host = config.get("listen_host", "127.0.0.1")
|
self.host = config.get("listen_host", "127.0.0.1")
|
||||||
@@ -159,6 +184,30 @@ class ProxyServer:
|
|||||||
self._cache = ResponseCache(max_mb=CACHE_MAX_MB)
|
self._cache = ResponseCache(max_mb=CACHE_MAX_MB)
|
||||||
self._direct_fail_until: dict[str, float] = {}
|
self._direct_fail_until: dict[str, float] = {}
|
||||||
self._servers: list[asyncio.base_events.Server] = []
|
self._servers: list[asyncio.base_events.Server] = []
|
||||||
|
self._client_tasks: set[asyncio.Task] = set()
|
||||||
|
self._tcp_connect_timeout = self._cfg_float(
|
||||||
|
config, "tcp_connect_timeout", TCP_CONNECT_TIMEOUT, minimum=1.0,
|
||||||
|
)
|
||||||
|
self._download_min_size = self._cfg_int(
|
||||||
|
config, "chunked_download_min_size", 5 * 1024 * 1024, minimum=0,
|
||||||
|
)
|
||||||
|
self._download_chunk_size = self._cfg_int(
|
||||||
|
config, "chunked_download_chunk_size", 512 * 1024, minimum=64 * 1024,
|
||||||
|
)
|
||||||
|
self._download_max_parallel = self._cfg_int(
|
||||||
|
config, "chunked_download_max_parallel", 8, minimum=1,
|
||||||
|
)
|
||||||
|
self._download_max_chunks = self._cfg_int(
|
||||||
|
config, "chunked_download_max_chunks", 256, minimum=1,
|
||||||
|
)
|
||||||
|
self._download_extensions, self._download_any_extension = (
|
||||||
|
self._normalize_download_extensions(
|
||||||
|
config.get(
|
||||||
|
"chunked_download_extensions",
|
||||||
|
list(self._DOWNLOAD_DEFAULT_EXTS),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# hosts override — DNS fake-map: domain/suffix → IP
|
# hosts override — DNS fake-map: domain/suffix → IP
|
||||||
# Checked before any real DNS lookup; supports exact and suffix matching.
|
# Checked before any real DNS lookup; supports exact and suffix matching.
|
||||||
@@ -198,6 +247,55 @@ class ProxyServer:
|
|||||||
|
|
||||||
# ── Host-policy helpers ───────────────────────────────────────
|
# ── Host-policy helpers ───────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int:
|
||||||
|
try:
|
||||||
|
value = int(config.get(key, default))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = default
|
||||||
|
return max(minimum, value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _cfg_float(config: dict, key: str, default: float,
|
||||||
|
*, minimum: float = 0.1) -> float:
|
||||||
|
try:
|
||||||
|
value = float(config.get(key, default))
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
value = default
|
||||||
|
return max(minimum, value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _normalize_download_extensions(cls, raw) -> tuple[tuple[str, ...], bool]:
|
||||||
|
values = raw if isinstance(raw, (list, tuple)) else cls._DOWNLOAD_DEFAULT_EXTS
|
||||||
|
normalized: list[str] = []
|
||||||
|
any_extension = False
|
||||||
|
seen: set[str] = set()
|
||||||
|
for item in values:
|
||||||
|
ext = str(item).strip().lower()
|
||||||
|
if not ext:
|
||||||
|
continue
|
||||||
|
if ext in {"*", ".*"}:
|
||||||
|
any_extension = True
|
||||||
|
continue
|
||||||
|
if not ext.startswith("."):
|
||||||
|
ext = "." + ext
|
||||||
|
if ext not in seen:
|
||||||
|
seen.add(ext)
|
||||||
|
normalized.append(ext)
|
||||||
|
if not normalized and not any_extension:
|
||||||
|
normalized = list(cls._DOWNLOAD_DEFAULT_EXTS)
|
||||||
|
return tuple(normalized), any_extension
|
||||||
|
|
||||||
|
def _track_current_task(self) -> asyncio.Task | None:
|
||||||
|
task = asyncio.current_task()
|
||||||
|
if task is not None:
|
||||||
|
self._client_tasks.add(task)
|
||||||
|
return task
|
||||||
|
|
||||||
|
def _untrack_task(self, task: asyncio.Task | None) -> None:
|
||||||
|
if task is not None:
|
||||||
|
self._client_tasks.discard(task)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]:
|
def _load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]:
|
||||||
"""Accept a list of host strings; return (exact_set, suffix_tuple).
|
"""Accept a list of host strings; return (exact_set, suffix_tuple).
|
||||||
@@ -352,15 +450,18 @@ class ProxyServer:
|
|||||||
self.socks_host, self.socks_port,
|
self.socks_host, self.socks_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
async with http_srv:
|
try:
|
||||||
if socks_srv:
|
async with http_srv:
|
||||||
async with socks_srv:
|
if socks_srv:
|
||||||
await asyncio.gather(
|
async with socks_srv:
|
||||||
http_srv.serve_forever(),
|
await asyncio.gather(
|
||||||
socks_srv.serve_forever(),
|
http_srv.serve_forever(),
|
||||||
)
|
socks_srv.serve_forever(),
|
||||||
else:
|
)
|
||||||
await http_srv.serve_forever()
|
else:
|
||||||
|
await http_srv.serve_forever()
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
raise
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
"""Shut down all listeners and release relay resources."""
|
"""Shut down all listeners and release relay resources."""
|
||||||
@@ -375,6 +476,15 @@ class ProxyServer:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
self._servers = []
|
self._servers = []
|
||||||
|
|
||||||
|
current = asyncio.current_task()
|
||||||
|
client_tasks = [task for task in self._client_tasks if task is not current]
|
||||||
|
for task in client_tasks:
|
||||||
|
task.cancel()
|
||||||
|
if client_tasks:
|
||||||
|
await asyncio.gather(*client_tasks, return_exceptions=True)
|
||||||
|
self._client_tasks.clear()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.fronter.close()
|
await self.fronter.close()
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -384,6 +494,7 @@ class ProxyServer:
|
|||||||
|
|
||||||
async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||||
addr = writer.get_extra_info("peername")
|
addr = writer.get_extra_info("peername")
|
||||||
|
task = self._track_current_task()
|
||||||
try:
|
try:
|
||||||
first_line = await asyncio.wait_for(reader.readline(), timeout=30)
|
first_line = await asyncio.wait_for(reader.readline(), timeout=30)
|
||||||
if not first_line:
|
if not first_line:
|
||||||
@@ -400,6 +511,16 @@ class ProxyServer:
|
|||||||
if line in (b"\r\n", b"\n", b""):
|
if line in (b"\r\n", b"\n", b""):
|
||||||
break
|
break
|
||||||
|
|
||||||
|
if _has_unsupported_transfer_encoding(header_block):
|
||||||
|
log.warning("Unsupported Transfer-Encoding on client request")
|
||||||
|
writer.write(
|
||||||
|
b"HTTP/1.1 501 Not Implemented\r\n"
|
||||||
|
b"Connection: close\r\n"
|
||||||
|
b"Content-Length: 0\r\n\r\n"
|
||||||
|
)
|
||||||
|
await writer.drain()
|
||||||
|
return
|
||||||
|
|
||||||
request_line = first_line.decode(errors="replace").strip()
|
request_line = first_line.decode(errors="replace").strip()
|
||||||
parts = request_line.split(" ", 2)
|
parts = request_line.split(" ", 2)
|
||||||
if len(parts) < 2:
|
if len(parts) < 2:
|
||||||
@@ -412,11 +533,14 @@ class ProxyServer:
|
|||||||
else:
|
else:
|
||||||
await self._do_http(header_block, reader, writer)
|
await self._do_http(header_block, reader, writer)
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.debug("Timeout: %s", addr)
|
log.debug("Timeout: %s", addr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("Error (%s): %s", addr, e)
|
log.error("Error (%s): %s", addr, e)
|
||||||
finally:
|
finally:
|
||||||
|
self._untrack_task(task)
|
||||||
try:
|
try:
|
||||||
writer.close()
|
writer.close()
|
||||||
await writer.wait_closed()
|
await writer.wait_closed()
|
||||||
@@ -426,6 +550,7 @@ class ProxyServer:
|
|||||||
async def _on_socks_client(self, reader: asyncio.StreamReader,
|
async def _on_socks_client(self, reader: asyncio.StreamReader,
|
||||||
writer: asyncio.StreamWriter):
|
writer: asyncio.StreamWriter):
|
||||||
addr = writer.get_extra_info("peername")
|
addr = writer.get_extra_info("peername")
|
||||||
|
task = self._track_current_task()
|
||||||
try:
|
try:
|
||||||
header = await asyncio.wait_for(reader.readexactly(2), timeout=15)
|
header = await asyncio.wait_for(reader.readexactly(2), timeout=15)
|
||||||
ver, nmethods = header[0], header[1]
|
ver, nmethods = header[0], header[1]
|
||||||
@@ -475,11 +600,14 @@ class ProxyServer:
|
|||||||
|
|
||||||
except asyncio.IncompleteReadError:
|
except asyncio.IncompleteReadError:
|
||||||
pass
|
pass
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
log.debug("SOCKS5 timeout: %s", addr)
|
log.debug("SOCKS5 timeout: %s", addr)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("SOCKS5 error (%s): %s", addr, e)
|
log.error("SOCKS5 error (%s): %s", addr, e)
|
||||||
finally:
|
finally:
|
||||||
|
self._untrack_task(task)
|
||||||
try:
|
try:
|
||||||
writer.close()
|
writer.close()
|
||||||
await writer.wait_closed()
|
await writer.wait_closed()
|
||||||
@@ -783,7 +911,9 @@ class ProxyServer:
|
|||||||
"""
|
"""
|
||||||
target_ip = connect_ip or host
|
target_ip = connect_ip or host
|
||||||
try:
|
try:
|
||||||
r_remote, w_remote = await self._open_tcp_connection(target_ip, port, timeout=10)
|
r_remote, w_remote = await self._open_tcp_connection(
|
||||||
|
target_ip, port, timeout=self._tcp_connect_timeout,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("Direct tunnel connect failed (%s via %s): %s",
|
log.error("Direct tunnel connect failed (%s via %s): %s",
|
||||||
host, target_ip, e)
|
host, target_ip, e)
|
||||||
@@ -858,7 +988,7 @@ class ProxyServer:
|
|||||||
ssl=ssl_ctx_client,
|
ssl=ssl_ctx_client,
|
||||||
server_hostname=sni_out,
|
server_hostname=sni_out,
|
||||||
),
|
),
|
||||||
timeout=10,
|
timeout=self._tcp_connect_timeout,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.error("SNI-rewrite outbound connect failed (%s via %s): %s",
|
log.error("SNI-rewrite outbound connect failed (%s via %s): %s",
|
||||||
@@ -969,6 +1099,15 @@ class ProxyServer:
|
|||||||
|
|
||||||
# Read body
|
# Read body
|
||||||
body = b""
|
body = b""
|
||||||
|
if _has_unsupported_transfer_encoding(header_block):
|
||||||
|
log.warning("Unsupported Transfer-Encoding → %s:%d", host, port)
|
||||||
|
writer.write(
|
||||||
|
b"HTTP/1.1 501 Not Implemented\r\n"
|
||||||
|
b"Connection: close\r\n"
|
||||||
|
b"Content-Length: 0\r\n\r\n"
|
||||||
|
)
|
||||||
|
await writer.drain()
|
||||||
|
break
|
||||||
length = _parse_content_length(header_block)
|
length = _parse_content_length(header_block)
|
||||||
if length > MAX_REQUEST_BODY_BYTES:
|
if length > MAX_REQUEST_BODY_BYTES:
|
||||||
raise ValueError(f"Request body too large: {length} bytes")
|
raise ValueError(f"Request body too large: {length} bytes")
|
||||||
@@ -1008,25 +1147,7 @@ class ProxyServer:
|
|||||||
|
|
||||||
log.info("MITM → %s %s", method, url)
|
log.info("MITM → %s %s", method, url)
|
||||||
|
|
||||||
# ── CORS: extract relevant request headers ────────────────────
|
if await self._maybe_stream_download(method, url, headers, body, writer):
|
||||||
origin = next(
|
|
||||||
(v for k, v in headers.items() if k.lower() == "origin"), ""
|
|
||||||
)
|
|
||||||
acr_method = next(
|
|
||||||
(v for k, v in headers.items()
|
|
||||||
if k.lower() == "access-control-request-method"), ""
|
|
||||||
)
|
|
||||||
acr_headers = next(
|
|
||||||
(v for k, v in headers.items()
|
|
||||||
if k.lower() == "access-control-request-headers"), ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# CORS preflight — respond directly; UrlFetchApp doesn't
|
|
||||||
# support OPTIONS so forwarding it would always fail.
|
|
||||||
if method.upper() == "OPTIONS" and acr_method:
|
|
||||||
log.debug("CORS preflight → %s (responding locally)", url[:60])
|
|
||||||
writer.write(self._cors_preflight_response(origin, acr_method, acr_headers))
|
|
||||||
await writer.drain()
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Check local cache first (GET only)
|
# Check local cache first (GET only)
|
||||||
@@ -1057,11 +1178,6 @@ class ProxyServer:
|
|||||||
self._cache.put(url, response, ttl)
|
self._cache.put(url, response, ttl)
|
||||||
log.debug("Cached (%ds): %s", ttl, url[:60])
|
log.debug("Cached (%ds): %s", ttl, url[:60])
|
||||||
|
|
||||||
# Inject permissive CORS headers whenever the browser
|
|
||||||
# sent an Origin (cross-origin XHR / fetch).
|
|
||||||
if origin and response:
|
|
||||||
response = self._inject_cors_headers(response, origin)
|
|
||||||
|
|
||||||
self._log_response_summary(url, response)
|
self._log_response_summary(url, response)
|
||||||
|
|
||||||
writer.write(response)
|
writer.write(response)
|
||||||
@@ -1077,64 +1193,6 @@ class ProxyServer:
|
|||||||
log.error("MITM handler error (%s): %s", host, e)
|
log.error("MITM handler error (%s): %s", host, e)
|
||||||
break
|
break
|
||||||
|
|
||||||
# ── CORS helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _cors_preflight_response(origin: str, acr_method: str, acr_headers: str) -> bytes:
|
|
||||||
"""Return a 204 No Content response that satisfies a CORS preflight."""
|
|
||||||
allow_origin = origin or "*"
|
|
||||||
allow_methods = (
|
|
||||||
f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS"
|
|
||||||
if acr_method else
|
|
||||||
"GET, POST, PUT, DELETE, PATCH, OPTIONS"
|
|
||||||
)
|
|
||||||
allow_headers = acr_headers or "*"
|
|
||||||
return (
|
|
||||||
"HTTP/1.1 204 No Content\r\n"
|
|
||||||
f"Access-Control-Allow-Origin: {allow_origin}\r\n"
|
|
||||||
f"Access-Control-Allow-Methods: {allow_methods}\r\n"
|
|
||||||
f"Access-Control-Allow-Headers: {allow_headers}\r\n"
|
|
||||||
"Access-Control-Allow-Credentials: true\r\n"
|
|
||||||
"Access-Control-Max-Age: 86400\r\n"
|
|
||||||
"Vary: Origin\r\n"
|
|
||||||
"Content-Length: 0\r\n"
|
|
||||||
"\r\n"
|
|
||||||
).encode()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _inject_cors_headers(response: bytes, origin: str) -> bytes:
|
|
||||||
"""Inject CORS headers only if the upstream response lacks them.
|
|
||||||
|
|
||||||
We must NOT overwrite the origin server's CORS headers: sites like
|
|
||||||
x.com return carefully-scoped Access-Control-Allow-Headers that list
|
|
||||||
specific custom headers (e.g. x-csrf-token). Replacing them with
|
|
||||||
wildcards together with Allow-Credentials: true makes browsers
|
|
||||||
reject the response (per the Fetch spec, "*" is literal when
|
|
||||||
credentials are included), which the site then blames on privacy
|
|
||||||
extensions. So we only fill in what the server omitted.
|
|
||||||
"""
|
|
||||||
sep = b"\r\n\r\n"
|
|
||||||
if sep not in response:
|
|
||||||
return response
|
|
||||||
header_section, body = response.split(sep, 1)
|
|
||||||
lines = header_section.decode(errors="replace").split("\r\n")
|
|
||||||
|
|
||||||
existing = {ln.split(":", 1)[0].strip().lower()
|
|
||||||
for ln in lines if ":" in ln}
|
|
||||||
|
|
||||||
# If the upstream already handled CORS, leave it completely alone.
|
|
||||||
if "access-control-allow-origin" in existing:
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Otherwise inject a minimal, credential-safe set (no wildcards,
|
|
||||||
# since wildcards combined with credentials are invalid).
|
|
||||||
allow_origin = origin or "*"
|
|
||||||
additions = [f"Access-Control-Allow-Origin: {allow_origin}"]
|
|
||||||
if allow_origin != "*":
|
|
||||||
additions.append("Access-Control-Allow-Credentials: true")
|
|
||||||
additions.append("Vary: Origin")
|
|
||||||
return ("\r\n".join(lines + additions) + "\r\n\r\n").encode() + body
|
|
||||||
|
|
||||||
async def _relay_smart(self, method, url, headers, body):
|
async def _relay_smart(self, method, url, headers, body):
|
||||||
"""Choose optimal relay strategy based on request type.
|
"""Choose optimal relay strategy based on request type.
|
||||||
|
|
||||||
@@ -1156,22 +1214,67 @@ class ProxyServer:
|
|||||||
# Only probe with Range when the URL looks like a big file.
|
# Only probe with Range when the URL looks like a big file.
|
||||||
if self._is_likely_download(url, headers):
|
if self._is_likely_download(url, headers):
|
||||||
return await self.fronter.relay_parallel(
|
return await self.fronter.relay_parallel(
|
||||||
method, url, headers, body
|
method,
|
||||||
|
url,
|
||||||
|
headers,
|
||||||
|
body,
|
||||||
|
chunk_size=self._download_chunk_size,
|
||||||
|
max_parallel=self._download_max_parallel,
|
||||||
|
max_chunks=self._download_max_chunks,
|
||||||
|
min_size=self._download_min_size,
|
||||||
)
|
)
|
||||||
return await self.fronter.relay(method, url, headers, body)
|
return await self.fronter.relay(method, url, headers, body)
|
||||||
|
|
||||||
def _is_likely_download(self, url: str, headers: dict) -> bool:
|
def _is_likely_download(self, url: str, headers: dict) -> bool:
|
||||||
"""Heuristic: is this URL likely a large file download?"""
|
"""Heuristic: is this URL likely a large file download?"""
|
||||||
path = url.split("?")[0].lower()
|
path = url.split("?")[0].lower()
|
||||||
for ext in LARGE_FILE_EXTS:
|
if self._download_any_extension:
|
||||||
|
return True
|
||||||
|
for ext in self._download_extensions:
|
||||||
if path.endswith(ext):
|
if path.endswith(ext):
|
||||||
return True
|
return True
|
||||||
|
accept = self._header_value(headers, "accept").lower()
|
||||||
|
if any(marker in accept for marker in self._DOWNLOAD_ACCEPT_MARKERS):
|
||||||
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def _maybe_stream_download(self, method: str, url: str,
|
||||||
|
headers: dict | None, body: bytes,
|
||||||
|
writer) -> bool:
|
||||||
|
if method.upper() != "GET" or body:
|
||||||
|
return False
|
||||||
|
if headers:
|
||||||
|
for key in headers:
|
||||||
|
if key.lower() == "range":
|
||||||
|
return False
|
||||||
|
effective_headers = headers or {}
|
||||||
|
if not self._is_likely_download(url, effective_headers):
|
||||||
|
return False
|
||||||
|
if not self.fronter.stream_download_allowed(url):
|
||||||
|
return False
|
||||||
|
return await self.fronter.stream_parallel_download(
|
||||||
|
url,
|
||||||
|
effective_headers,
|
||||||
|
writer,
|
||||||
|
chunk_size=self._download_chunk_size,
|
||||||
|
max_parallel=self._download_max_parallel,
|
||||||
|
max_chunks=self._download_max_chunks,
|
||||||
|
min_size=self._download_min_size,
|
||||||
|
)
|
||||||
|
|
||||||
# ── Plain HTTP forwarding ─────────────────────────────────────
|
# ── Plain HTTP forwarding ─────────────────────────────────────
|
||||||
|
|
||||||
async def _do_http(self, header_block: bytes, reader, writer):
|
async def _do_http(self, header_block: bytes, reader, writer):
|
||||||
body = b""
|
body = b""
|
||||||
|
if _has_unsupported_transfer_encoding(header_block):
|
||||||
|
log.warning("Unsupported Transfer-Encoding on plain HTTP request")
|
||||||
|
writer.write(
|
||||||
|
b"HTTP/1.1 501 Not Implemented\r\n"
|
||||||
|
b"Connection: close\r\n"
|
||||||
|
b"Content-Length: 0\r\n\r\n"
|
||||||
|
)
|
||||||
|
await writer.drain()
|
||||||
|
return
|
||||||
length = _parse_content_length(header_block)
|
length = _parse_content_length(header_block)
|
||||||
if length > MAX_REQUEST_BODY_BYTES:
|
if length > MAX_REQUEST_BODY_BYTES:
|
||||||
writer.write(b"HTTP/1.1 413 Content Too Large\r\n\r\n")
|
writer.write(b"HTTP/1.1 413 Content Too Large\r\n\r\n")
|
||||||
@@ -1194,22 +1297,7 @@ class ProxyServer:
|
|||||||
k, v = raw_line.decode(errors="replace").split(":", 1)
|
k, v = raw_line.decode(errors="replace").split(":", 1)
|
||||||
headers[k.strip()] = v.strip()
|
headers[k.strip()] = v.strip()
|
||||||
|
|
||||||
# ── CORS preflight over plain HTTP ────────────────────────────
|
if await self._maybe_stream_download(method, url, headers, body, writer):
|
||||||
origin = next(
|
|
||||||
(v for k, v in headers.items() if k.lower() == "origin"), ""
|
|
||||||
)
|
|
||||||
acr_method = next(
|
|
||||||
(v for k, v in headers.items()
|
|
||||||
if k.lower() == "access-control-request-method"), ""
|
|
||||||
)
|
|
||||||
acr_headers_val = next(
|
|
||||||
(v for k, v in headers.items()
|
|
||||||
if k.lower() == "access-control-request-headers"), ""
|
|
||||||
)
|
|
||||||
if method.upper() == "OPTIONS" and acr_method:
|
|
||||||
log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60])
|
|
||||||
writer.write(self._cors_preflight_response(origin, acr_method, acr_headers_val))
|
|
||||||
await writer.drain()
|
|
||||||
return
|
return
|
||||||
|
|
||||||
# Cache check for GET
|
# Cache check for GET
|
||||||
@@ -1227,9 +1315,6 @@ class ProxyServer:
|
|||||||
if ttl > 0:
|
if ttl > 0:
|
||||||
self._cache.put(url, response, ttl)
|
self._cache.put(url, response, ttl)
|
||||||
|
|
||||||
# Inject CORS headers for cross-origin requests
|
|
||||||
if origin and response:
|
|
||||||
response = self._inject_cors_headers(response, origin)
|
|
||||||
self._log_response_summary(url, response)
|
self._log_response_summary(url, response)
|
||||||
|
|
||||||
writer.write(response)
|
writer.write(response)
|
||||||
|
|||||||
Reference in New Issue
Block a user