mirror of
https://github.com/masterking32/MasterHttpRelayVPN.git
synced 2026-05-17 21:24:37 +03:00
2849 lines
117 KiB
Python
2849 lines
117 KiB
Python
"""
|
||
Apps Script relay engine.

Domain fronting via Google Apps Script: POST JSON to script.google.com
(fronted through www.google.com). Apps Script fetches the target URL and
returns the response.

relay() — JSON-based HTTP relay through Apps Script
|
||
"""
|
||
|
||
import asyncio
|
||
import base64
|
||
import codecs
|
||
import hashlib
|
||
import json
|
||
import logging
|
||
import re
|
||
import socket
|
||
import ssl
|
||
import statistics
|
||
import tempfile
|
||
import time
|
||
from urllib.parse import urlparse
|
||
|
||
try:
|
||
import certifi
|
||
except Exception: # optional dependency fallback
|
||
certifi = None
|
||
|
||
from core import codec
|
||
from core.constants import (
|
||
BATCH_MAX,
|
||
BATCH_WINDOW_MACRO,
|
||
BATCH_WINDOW_MICRO,
|
||
CONN_TTL,
|
||
MAX_RESPONSE_BODY_BYTES,
|
||
POOL_MAX,
|
||
POOL_MIN_IDLE,
|
||
RELAY_TIMEOUT,
|
||
SCRIPT_BLACKLIST_TTL,
|
||
SEMAPHORE_MAX,
|
||
STATEFUL_HEADER_NAMES,
|
||
STATIC_EXTS,
|
||
STATS_LOG_INTERVAL,
|
||
STATS_LOG_TOP_N,
|
||
TLS_CONNECT_TIMEOUT,
|
||
WARM_POOL_COUNT,
|
||
)
|
||
from .fronting_support import (
|
||
HostStat,
|
||
build_sni_pool,
|
||
format_bytes_human,
|
||
format_elapsed_short,
|
||
parse_content_range,
|
||
progress_line,
|
||
render_progress_bar,
|
||
spool_read,
|
||
spool_write,
|
||
validate_range_response,
|
||
)
|
||
from .relay_response import (
|
||
classify_relay_error,
|
||
error_response,
|
||
extract_apps_script_user_html,
|
||
load_relay_json,
|
||
parse_relay_json,
|
||
parse_relay_response,
|
||
split_raw_response,
|
||
split_set_cookie,
|
||
)
|
||
from .http_reader import read_http_response
|
||
|
||
log = logging.getLogger("Fronter")
|
||
|
||
|
||
def _mask_sid(sid: str) -> str:
|
||
"""Return a safe display form of an Apps Script deployment ID.
|
||
|
||
Full deployment IDs look like ``AKfycbwLd8Ca2BIsMWs5uN3x7...``
|
||
and should never appear in log files or screenshots that users might
|
||
share in issue reports. Show only the first 6 and last 4 characters
|
||
so it's identifiable but not usable to hijack the deployment:
|
||
|
||
AKfycb…5dGE
|
||
"""
|
||
if not sid or len(sid) <= 12:
|
||
return sid or "(none)"
|
||
return f"{sid[:6]}\u2026{sid[-4:]}"
|
||
|
||
|
||
class DomainFronter:
|
||
_STATIC_EXTS = STATIC_EXTS
|
||
_H2_FAILURE_COOLDOWN = 15.0 # reduced: DPI token bucket refills in ~8-10s
|
||
_H2_FAILURE_THRESHOLD = 5 # raised: needs genuine consecutive failures
|
||
# URL extensions that almost always produce large responses (fonts, images,
|
||
# media). These are isolated into their own H2 sub-batch so a 400 kB font
|
||
# doesn't block a 2 kB JS file waiting for the same Apps Script response.
|
||
_HEAVY_EXTENSIONS = frozenset({
|
||
"woff2", "woff", "ttf", "eot", "otf",
|
||
"jpg", "jpeg", "png", "gif", "webp", "avif", "ico",
|
||
"mp4", "mp3", "wav", "webm", "ogg", "flac",
|
||
})
|
||
_DOWNLOAD_STREAM_COOLDOWN = 300.0
|
||
_COALESCE_VARY_HEADERS = (
|
||
"accept",
|
||
"accept-language",
|
||
"user-agent",
|
||
"sec-fetch-dest",
|
||
"sec-fetch-mode",
|
||
"sec-fetch-site",
|
||
)
|
||
_SAFE_RETRY_METHODS = {"GET", "HEAD", "OPTIONS"}
|
||
_APPS_SCRIPT_DEFAULT_LANG = "en"
|
||
|
||
def __init__(self, config: dict):
    """Configure the fronter from a settings mapping.

    Reads connection, Apps Script, fan-out, batching, HTTP/2 and exit-node
    options from *config*, creates the asyncio primitives used by the
    connection pool and batch collector, and logs the effective setup.

    Numeric batch options (``batch_window_micro``, ``batch_window_macro``,
    ``batch_max``) are parsed through ``_cfg_float`` / ``_cfg_int`` like
    the other numeric knobs, so a malformed value falls back to the
    built-in default instead of raising during construction.
    """
    # ── Fronting endpoint and SNI rotation ──────────────────────────
    self.connect_host = config.get("google_ip", "216.239.38.120")
    self.sni_host = config.get("front_domain", "www.google.com")
    # SNI rotation pool — rotated per new outbound TLS connection so
    # DPI systems can't fingerprint traffic as "always one SNI".
    self._sni_hosts = build_sni_pool(
        self.sni_host, config.get("front_domains"),
    )
    self._sni_idx = 0
    self._sni_probe_task: asyncio.Task | None = None
    self.http_host = "script.google.com"

    # ── Apps Script deployments ─────────────────────────────────────
    # Multi-script round-robin for higher throughput.
    script = config.get("script_ids") or config.get("script_id")
    self._script_ids = script if isinstance(script, list) else [script]
    self._script_idx = 0
    self.script_id = self._script_ids[0]  # backward compat / logging
    self._dev_available = False  # True if /dev endpoint works (no redirect, ~400ms faster)
    self._apps_script_lang = str(
        config.get("apps_script_lang", self._APPS_SCRIPT_DEFAULT_LANG)
    ).strip().lower() or self._APPS_SCRIPT_DEFAULT_LANG

    # Simple execution monitor: log total consumed Apps Script executions.
    self._execution_report_interval = 5.0
    self._exec_total = 0
    self._execution_task: asyncio.Task | None = None

    # Fan-out parallel relay: fire N Apps Script instances concurrently,
    # keep the first successful response, cancel the rest.  Script IDs
    # that fail or time out get blacklisted for SCRIPT_BLACKLIST_TTL so
    # a single slow container stops poisoning tail latency.
    self._parallel_relay = min(
        self._cfg_int(config, "parallel_relay", 1, minimum=1),
        len(self._script_ids),
    )
    self._sid_blacklist: dict[str, float] = {}
    self._blacklist_ttl = SCRIPT_BLACKLIST_TTL

    # Per-host stats (requests, cache hits, bytes, cumulative latency).
    self._per_site: dict[str, HostStat] = {}
    self._stats_task: asyncio.Task | None = None

    # ── TLS and timeouts ────────────────────────────────────────────
    self.auth_key = config.get("auth_key", "")
    self.verify_ssl = config.get("verify_ssl", True)
    # Build the SSLContext once so every TLS connection open reuses it
    # instead of rebuilding the CA bundle and context on each dial.
    self._ssl_context: ssl.SSLContext = self._build_ssl_ctx(self.verify_ssl)
    self._relay_timeout = self._cfg_float(
        config, "relay_timeout", RELAY_TIMEOUT, minimum=1.0,
    )
    self._tls_connect_timeout = self._cfg_float(
        config, "tls_connect_timeout", TLS_CONNECT_TIMEOUT, minimum=1.0,
    )
    self._sni_probe_timeout = min(self._tls_connect_timeout, 4.0)
    # Keep response cap as a code-level constant to avoid exposing an
    # advanced memory-safety knob in end-user config.
    self._max_response_body_bytes = MAX_RESPONSE_BODY_BYTES

    # ── H1 connection pool — TTL-based, pre-warmed ──────────────────
    self._pool: list[tuple[asyncio.StreamReader, asyncio.StreamWriter, float]] = []
    self._pool_lock = asyncio.Lock()
    self._pool_max = POOL_MAX
    self._conn_ttl = CONN_TTL
    self._semaphore = asyncio.Semaphore(SEMAPHORE_MAX)
    self._warmed = False
    self._refilling = False
    self._pool_min_idle = POOL_MIN_IDLE
    # H1 is fallback-only when H2 is active.  We don't know yet whether
    # the H2 pool will succeed (set later in __init__), so default to the
    # full warm count and let the H2 init below shrink it if applicable.
    self._warm_count = WARM_POOL_COUNT
    self._maintenance_task: asyncio.Task | None = None
    self._keepalive_task: asyncio.Task | None = None
    self._warm_task: asyncio.Task | None = None
    self._bg_tasks: set[asyncio.Task] = set()
    # Set by _do_warm() when the initial TLS connection batch is open.
    # The very first relay() call awaits this (with a short timeout) so it
    # never dispatches a request onto a completely cold pool.
    self._pool_ready = asyncio.Event()

    # ── Batch collector for grouping concurrent relay() calls ───────
    self._batch_lock = asyncio.Lock()
    self._batch_pending: list[tuple[dict, asyncio.Future]] = []
    self._batch_task: asyncio.Task | None = None
    # Parsed defensively (same as the timeouts above): a non-numeric
    # value falls back to the default rather than aborting start-up.
    self._batch_window_micro = self._cfg_float(
        config, "batch_window_micro", BATCH_WINDOW_MICRO, minimum=0.0,
    )
    self._batch_window_macro = self._cfg_float(
        config, "batch_window_macro", BATCH_WINDOW_MACRO, minimum=0.0,
    )
    self._batch_max = self._cfg_int(config, "batch_max", BATCH_MAX, minimum=1)
    # enable_batch=false → each request gets its own H2 stream → N×2 KiB/s
    # aggregate throughput instead of all requests sharing one stream.
    # Recommended when DPI does per-stream rate limiting (e.g. Iran).
    self._batch_permanent_disable: bool = not bool(config.get("enable_batch", True))
    self._batch_enabled = not self._batch_permanent_disable
    self._batch_disabled_at = 0.0
    self._batch_cooldown = 60
    # enable_sub_batch=false → all batches are sent as a single Apps Script
    # call regardless of how many H2 connections are live.  Saves quota at
    # the cost of parallel DPI bypass (each connection no longer gets its
    # own token bucket).  Useful when quota is the binding constraint.
    self._sub_batch_enabled: bool = bool(config.get("enable_sub_batch", True))

    # Request coalescing — dedup concurrent identical GETs.
    self._coalesce: dict[str, list[asyncio.Future]] = {}
    self._h2_failure_streak = 0
    self._h2_disabled_until = 0.0
    # When the H2 reader loop ends, EVERY in-flight stream raises a
    # ConnectionError simultaneously.  Without de-duping by connection
    # generation, a single drop with 5+ in-flight streams trips the
    # disable threshold and forces a 15s H1 fallback for no reason.
    self._h2_last_failure_gen: int = -1
    self._stream_download_disabled_until: dict[str, float] = {}

    # ── HTTP/2 multiplexing — pool of parallel connections ──────────
    # Iran's DPI shapes per-TCP-connection; N separate connections each
    # get their own independent token bucket, giving ~N× throughput.
    self._h2 = None
    self._h2_pool: list = []
    self._h2_pool_idx: int = 0
    try:
        from .h2_transport import H2Transport, H2_AVAILABLE
        if H2_AVAILABLE:
            n_conns = self._cfg_int(config, "h2_connections", 3, minimum=1)
            no_sni = bool(config.get("no_sni", False))
            try:
                ping_interval = float(config.get("ping_interval", 0.2))
            except (TypeError, ValueError):
                ping_interval = 0.2
            self._h2_pool = [
                H2Transport(
                    self.connect_host, self.sni_host, self.verify_ssl,
                    sni_hosts=self._sni_hosts,
                    no_sni=no_sni,
                    ping_interval=ping_interval,
                )
                for _ in range(n_conns)
            ]
            self._h2 = self._h2_pool[0]  # primary; used for ping/reconnect
            log.info(
                "HTTP/2 multiplexing available — %d parallel connections "
                "(each gets its own DPI token bucket)",
                n_conns,
            )
            # H1 is now fallback-only — shrink the pool we keep warm.
            # We still want a few ready for instant fallback when H2 hits
            # a transient failure (cooldown window), but maintaining 30
            # warm + 15 idle connections that are virtually never used
            # wastes TLS handshakes and CPU.
            self._warm_count = min(self._warm_count, 6)
            self._pool_min_idle = min(self._pool_min_idle, 3)
    except ImportError:
        # The h2 transport module is optional; stay on HTTP/1.1 only.
        pass

    # ── Start-up summary ────────────────────────────────────────────
    if len(self._sni_hosts) > 1:
        log.info("SNI rotation pool (%d): %s",
                 len(self._sni_hosts), ", ".join(self._sni_hosts))
    if self._parallel_relay > 1:
        log.info("Fan-out relay: %d parallel Apps Script instances per request",
                 self._parallel_relay)
    log.info(
        "Execution monitor enabled: reporting total every %.0fs",
        self._execution_report_interval,
    )
    if self._batch_permanent_disable:
        log.info(
            "Batch DISABLED (enable_batch=false) — each request fires its own "
            "H2 stream for N×2 KiB/s aggregate throughput"
        )
    else:
        log.info(
            "Batch config: micro=%.0fms macro=%.0fms max=%d sub_batch=%s",
            self._batch_window_micro * 1000.0,
            self._batch_window_macro * 1000.0,
            self._batch_max,
            "on" if self._sub_batch_enabled else "off",
        )

    # ── Exit node — optional second hop with a non-Google exit IP ───
    # Useful for sites that block GCP/Apps Script IPs (e.g. ChatGPT).
    en_cfg = config.get("exit_node") or {}
    self._exit_node_enabled: bool = bool(en_cfg.get("enabled", False))
    self._exit_node_provider: str = self._normalize_exit_node_provider(
        en_cfg.get("provider"),
    )
    self._exit_node_url: str = self._resolve_exit_node_url(
        self._exit_node_provider,
        en_cfg,
    )
    self._exit_node_psk: str = str(en_cfg.get("psk") or "")
    self._exit_node_mode: str = str(en_cfg.get("mode") or "selective").lower()
    if self._exit_node_mode not in ("full", "selective"):
        self._exit_node_mode = "selective"
    # Host entries are normalised: lower-cased, trimmed, leading dots removed.
    self._exit_node_hosts: frozenset[str] = frozenset(
        str(h).lower().strip().lstrip(".")
        for h in (en_cfg.get("hosts") or [])
        if h
    )
    if self._exit_node_enabled and self._exit_node_url:
        log.info(
            "Exit node enabled [mode=%s, provider=%s]: %s",
            self._exit_node_mode,
            self._exit_node_provider,
            self._exit_node_url,
        )
    elif self._exit_node_enabled:
        log.warning(
            "Exit node is enabled but no URL is configured for provider '%s'",
            self._exit_node_provider,
        )

    # Capability log for content encodings.
    log.info("Response codecs: %s", codec.supported_encodings())
|
||
|
||
# ── helpers ───────────────────────────────────────────────────
|
||
|
||
@staticmethod
|
||
def _cfg_int(config: dict, key: str, default: int, *, minimum: int = 1) -> int:
|
||
try:
|
||
value = int(config.get(key, default))
|
||
except (TypeError, ValueError):
|
||
value = default
|
||
return max(minimum, value)
|
||
|
||
@staticmethod
|
||
def _cfg_float(config: dict, key: str, default: float,
|
||
*, minimum: float = 0.1) -> float:
|
||
try:
|
||
value = float(config.get(key, default))
|
||
except (TypeError, ValueError):
|
||
value = default
|
||
return max(minimum, value)
|
||
|
||
def _record_execution(self, sid: str, count: int = 1) -> None:
|
||
"""Record consumed Apps Script executions."""
|
||
if not sid or count <= 0:
|
||
return
|
||
self._exec_total += count
|
||
|
||
async def _execution_logger(self):
    """Background loop: report the execution total when it has changed.

    Sleeps ``_execution_report_interval`` seconds between checks, exits on
    cancellation, and logs (at DEBUG) any other error without stopping.
    """
    period = self._execution_report_interval
    previous = -1
    while True:
        try:
            await asyncio.sleep(period)
            current = self._exec_total
            if current == previous:
                continue
            previous = current
            log.info("Apps Script executions used so far: %d", current)
        except asyncio.CancelledError:
            break
        except Exception as exc:
            log.debug("Execution logger error: %s", exc)
|
||
|
||
@staticmethod
def _build_ssl_ctx(verify_ssl: bool) -> ssl.SSLContext:
    """Create the client TLS context used for every outbound connection.

    Starts from ``ssl.create_default_context()`` and, when the optional
    ``certifi`` package was imported, additionally loads its CA bundle
    (best effort — load errors are ignored).  With *verify_ssl* false,
    hostname checking and certificate verification are both turned off.
    """
    context = ssl.create_default_context()
    bundle_provider = certifi
    if bundle_provider is not None:
        try:
            context.load_verify_locations(cafile=bundle_provider.where())
        except Exception:
            pass
    if verify_ssl:
        return context
    # check_hostname must be cleared before verify_mode can be CERT_NONE.
    context.check_hostname = False
    context.verify_mode = ssl.CERT_NONE
    return context
|
||
|
||
def _ssl_ctx(self) -> ssl.SSLContext:
|
||
return self._ssl_context
|
||
|
||
def _h2_available(self) -> bool:
|
||
if not self._h2_pool or time.time() < self._h2_disabled_until:
|
||
return False
|
||
return any(t.is_connected for t in self._h2_pool)
|
||
|
||
def _pick_h2(self):
|
||
"""Round-robin pick a connected H2Transport from the pool.
|
||
|
||
Distributes relay requests across multiple TCP connections so each
|
||
benefits from its own independent DPI throughput budget.
|
||
Returns the primary transport when none are connected (caller will
|
||
trigger reconnection via the normal failure/cooldown path).
|
||
"""
|
||
pool = self._h2_pool
|
||
n = len(pool)
|
||
if not n:
|
||
return self._h2
|
||
for i in range(n):
|
||
t = pool[(self._h2_pool_idx + i) % n]
|
||
if t.is_connected:
|
||
self._h2_pool_idx = (self._h2_pool_idx + i + 1) % n
|
||
return t
|
||
# None connected — advance index and return primary
|
||
self._h2_pool_idx = (self._h2_pool_idx + 1) % n
|
||
return pool[0]
|
||
|
||
def _record_h2_success(self) -> None:
|
||
self._h2_failure_streak = 0
|
||
# Reset the generation guard so the *next* genuine drop is counted
|
||
# even if the connection happens to share its old generation key.
|
||
self._h2_last_failure_gen = -1
|
||
|
||
def _record_h2_failure(self, exc: Exception) -> None:
|
||
# De-dupe failures from a single connection drop event. When the
|
||
# H2 reader loop ends, every in-flight stream raises a transport
|
||
# error simultaneously — counting each as a separate failure trips
|
||
# the disable threshold from one drop with 5+ concurrent streams.
|
||
# Track failures per connection generation so a single drop counts
|
||
# at most once per H2 transport.
|
||
gen_key = -1
|
||
try:
|
||
if self._h2_pool:
|
||
# Use the sum of generations across the pool as a proxy
|
||
# for "have any connections been re-established since the
|
||
# last failure?". Bumps once per reconnect.
|
||
gen_key = sum(
|
||
getattr(t, "_conn_generation", 0) for t in self._h2_pool
|
||
)
|
||
elif self._h2 is not None:
|
||
gen_key = getattr(self._h2, "_conn_generation", 0)
|
||
except Exception:
|
||
gen_key = -1
|
||
if gen_key == self._h2_last_failure_gen and gen_key != -1:
|
||
# Same drop event — already counted.
|
||
return
|
||
self._h2_last_failure_gen = gen_key
|
||
|
||
self._h2_failure_streak += 1
|
||
# Extend the cooldown window on every failure so a burst of concurrent
|
||
# failures doesn't shorten the effective cooldown.
|
||
self._h2_disabled_until = max(
|
||
self._h2_disabled_until,
|
||
time.time() + self._H2_FAILURE_COOLDOWN,
|
||
)
|
||
# Log exactly once when the threshold is first crossed. Using ==
|
||
# (not >=) avoids re-logging on every subsequent failure from
|
||
# concurrent in-flight requests that all fail at the same moment.
|
||
if self._h2_failure_streak == self._H2_FAILURE_THRESHOLD:
|
||
log.warning(
|
||
"H2 temporarily disabled for %.0fs after %d consecutive failures (%s)",
|
||
self._H2_FAILURE_COOLDOWN,
|
||
self._h2_failure_streak,
|
||
type(exc).__name__,
|
||
)
|
||
|
||
@staticmethod
|
||
def _is_h2_transport_error(exc: BaseException) -> bool:
|
||
"""Return True only for genuine H2 *transport* failures.
|
||
|
||
Apps Script request timeouts (TimeoutError) and application-level
|
||
errors are NOT H2 transport failures — the connection may be fine.
|
||
Counting them pushes the failure streak toward the disable threshold
|
||
even when H2 is healthy, which causes unnecessary 15s fallbacks.
|
||
Only connection-level errors should disable H2.
|
||
"""
|
||
if isinstance(exc, asyncio.TimeoutError):
|
||
return False
|
||
if isinstance(exc, (ConnectionError, OSError, ssl.SSLError)):
|
||
return True
|
||
msg = str(exc).lower()
|
||
return any(k in msg for k in (
|
||
"connection closed", "connection lost", "stream error",
|
||
"alpn negotiation", "transport closed", "h2 reader",
|
||
"eof", "broken pipe",
|
||
))
|
||
|
||
def _stream_download_allowed(self, url: str) -> bool:
|
||
host = self._host_key(url)
|
||
if not host:
|
||
return True
|
||
until = self._stream_download_disabled_until.get(host, 0.0)
|
||
if until > time.time():
|
||
return False
|
||
if until:
|
||
self._stream_download_disabled_until.pop(host, None)
|
||
return True
|
||
|
||
def _mark_stream_download_failure(self, url: str, reason: str) -> None:
|
||
host = self._host_key(url)
|
||
if not host:
|
||
return
|
||
self._stream_download_disabled_until[host] = (
|
||
time.time() + self._DOWNLOAD_STREAM_COOLDOWN
|
||
)
|
||
log.warning(
|
||
"Parallel streaming disabled for host %s for %.0fs after failure (%s)",
|
||
host, self._DOWNLOAD_STREAM_COOLDOWN, reason,
|
||
)
|
||
|
||
def stream_download_allowed(self, url: str) -> bool:
|
||
return self._stream_download_allowed(url)
|
||
|
||
async def _open(self):
    """Open a TLS connection to the CDN and return ``(reader, writer)``.

    - TCP_NODELAY is set on the underlying socket so small H2/H1 writes
      aren't held back by Nagle's algorithm (up to ~40 ms per batch).
    - The *server_hostname* parameter sets the **TLS SNI** extension;
      it is taken from the rotation pool via ``_next_sni()`` so DPI
      can't fingerprint "always www.google.com" from the client side.

    The raw socket is closed on *any* failure, including cancellation:
    callers wrap this coroutine in ``asyncio.wait_for``, which cancels it
    on timeout, and ``CancelledError`` is not an ``Exception`` subclass —
    catching only ``Exception`` would leave the socket open when the
    timeout hits during the TCP connect.

    Raises:
        Whatever the socket / TLS layer raises; the error is re-raised
        after the socket has been closed.
    """
    loop = asyncio.get_running_loop()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        sock.setblocking(False)
        await loop.sock_connect(sock, (self.connect_host, 443))
        return await asyncio.open_connection(
            sock=sock,
            ssl=self._ssl_ctx(),
            server_hostname=self._next_sni(),
        )
    except BaseException:
        try:
            sock.close()
        except Exception:
            pass
        raise
|
||
|
||
def _next_sni(self) -> str:
|
||
"""Round-robin the next SNI from the rotation pool."""
|
||
sni = self._sni_hosts[self._sni_idx % len(self._sni_hosts)]
|
||
self._sni_idx += 1
|
||
return sni
|
||
|
||
async def _ensure_sni_ranked(self) -> None:
|
||
if len(self._sni_hosts) <= 1:
|
||
return
|
||
task = self._sni_probe_task
|
||
if task is None:
|
||
task = self._spawn(self._rank_sni_hosts())
|
||
self._sni_probe_task = task
|
||
try:
|
||
await task
|
||
except Exception as exc:
|
||
log.debug("SNI probe failed: %s", exc)
|
||
|
||
async def _rank_sni_hosts(self) -> None:
|
||
sid = self._script_ids[0]
|
||
original = list(self._sni_hosts)
|
||
|
||
ranked: list[tuple[float, str]] = []
|
||
failed: list[str] = []
|
||
for sni in original:
|
||
result = await self._probe_sni_latency(sni, sid)
|
||
if result is None:
|
||
failed.append(sni)
|
||
else:
|
||
ranked.append((result, sni))
|
||
|
||
if not ranked:
|
||
return
|
||
|
||
ranked.sort(key=lambda item: item[0])
|
||
reordered = [sni for _, sni in ranked] + failed
|
||
if reordered == original:
|
||
log.info(
|
||
"SNI probe kept order: %s",
|
||
", ".join(f"{sni} ({ms:.0f}ms)" for ms, sni in ranked),
|
||
)
|
||
return
|
||
|
||
self._sni_hosts = reordered
|
||
self._sni_idx = 0
|
||
for _t in self._h2_pool:
|
||
_t._sni_hosts = list(reordered)
|
||
_t._sni_idx = 0
|
||
log.info(
|
||
"SNI pool re-ranked by local probe: %s",
|
||
", ".join(f"{sni} ({ms:.0f}ms)" for ms, sni in ranked),
|
||
)
|
||
if failed:
|
||
log.info("SNI probe timed out: %s", ", ".join(failed))
|
||
|
||
async def _probe_sni_latency(self, sni: str, sid: str) -> float | None:
|
||
samples: list[float] = []
|
||
for _ in range(2):
|
||
sample = await self._probe_sni_latency_once(sni, sid)
|
||
if sample is not None:
|
||
samples.append(sample)
|
||
if not samples:
|
||
return None
|
||
return statistics.median(samples)
|
||
|
||
async def _probe_sni_latency_once(self, sni: str, sid: str) -> float | None:
    """Time one relay round trip to Apps Script using *sni* as TLS SNI.

    Opens a fresh TCP+TLS connection to ``connect_host:443``, POSTs a
    small relay request to the ``/exec`` endpoint of script *sid*, and
    reads until the response header block is complete (or 8 KiB / EOF).
    Each network step is bounded by ``_sni_probe_timeout``.

    Returns:
        Elapsed milliseconds measured from just before the TCP connect,
        or ``None`` on any error or when the reply does not start with
        ``HTTP/``.
    """
    body = json.dumps(
        {"m": "GET", "u": "http://example.com/", "k": self.auth_key}
    ).encode()
    path = f"/macros/s/{sid}/exec?hl={self._apps_script_lang}"
    request = (
        f"POST {path} HTTP/1.1\r\n"
        f"Host: {self.http_host}\r\n"
        "Content-Type: application/json\r\n"
        "Accept: application/json,text/plain,*/*\r\n"
        "Accept-Language: en-US,en;q=0.9\r\n"
        f"Content-Length: {len(body)}\r\n"
        "Connection: close\r\n\r\n"
    ).encode() + body
    limit = self._sni_probe_timeout
    loop = asyncio.get_running_loop()
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    sock.setblocking(False)
    t0 = time.perf_counter()
    reader = writer = None
    try:
        await asyncio.wait_for(
            loop.sock_connect(sock, (self.connect_host, 443)),
            timeout=limit,
        )
        tls_stream = asyncio.open_connection(
            sock=sock,
            ssl=self._ssl_ctx(),
            server_hostname=sni,
        )
        reader, writer = await asyncio.wait_for(tls_stream, timeout=limit)
        writer.write(request)
        await asyncio.wait_for(writer.drain(), timeout=limit)

        # Only the status line / header block matters for timing.
        received = b""
        while b"\r\n\r\n" not in received and len(received) < 8192:
            piece = await asyncio.wait_for(reader.read(512), timeout=limit)
            if not piece:
                break
            received += piece
        if not received.startswith(b"HTTP/"):
            return None
        return (time.perf_counter() - t0) * 1000
    except Exception:
        return None
    finally:
        if writer is not None:
            writer.close()
            try:
                await writer.wait_closed()
            except Exception:
                pass
        elif sock.fileno() != -1:
            # TLS stream was never established — close the raw socket.
            try:
                sock.close()
            except Exception:
                pass
|
||
|
||
async def _acquire(self):
|
||
"""Get a healthy TLS connection from pool (TTL-checked) or open new."""
|
||
now = asyncio.get_running_loop().time()
|
||
async with self._pool_lock:
|
||
while self._pool:
|
||
reader, writer, created = self._pool.pop()
|
||
if (now - created) < self._conn_ttl and not reader.at_eof():
|
||
# Eagerly replace the connection we just took
|
||
self._spawn(self._add_conn_to_pool())
|
||
return reader, writer, created
|
||
try:
|
||
writer.close()
|
||
except Exception:
|
||
pass
|
||
reader, writer = await asyncio.wait_for(
|
||
self._open(), timeout=self._tls_connect_timeout
|
||
)
|
||
# Pool was empty — trigger aggressive background refill
|
||
if not self._refilling:
|
||
self._refilling = True
|
||
self._spawn(self._refill_pool())
|
||
return reader, writer, asyncio.get_running_loop().time()
|
||
|
||
async def _release(self, reader, writer, created):
|
||
"""Return a connection to the pool if still young and healthy."""
|
||
now = asyncio.get_running_loop().time()
|
||
if (now - created) >= self._conn_ttl or reader.at_eof():
|
||
try:
|
||
writer.close()
|
||
except Exception:
|
||
pass
|
||
return
|
||
async with self._pool_lock:
|
||
if len(self._pool) < self._pool_max:
|
||
self._pool.append((reader, writer, created))
|
||
else:
|
||
try:
|
||
writer.close()
|
||
except Exception:
|
||
pass
|
||
|
||
def _next_script_id(self) -> str:
|
||
"""Round-robin across script IDs for load distribution.
|
||
|
||
Skips script IDs currently in the short-term blacklist (failing
|
||
or slow) unless *all* are blacklisted, in which case we fall back
|
||
to plain round-robin so traffic can still flow.
|
||
"""
|
||
n = len(self._script_ids)
|
||
for _ in range(n):
|
||
sid = self._script_ids[self._script_idx % n]
|
||
self._script_idx += 1
|
||
if not self._is_sid_blacklisted(sid):
|
||
return sid
|
||
# All blacklisted — clear expired entries and fall back.
|
||
self._prune_blacklist(force=True)
|
||
sid = self._script_ids[self._script_idx % n]
|
||
self._script_idx += 1
|
||
return sid
|
||
|
||
def _is_sid_blacklisted(self, sid: str) -> bool:
|
||
until = self._sid_blacklist.get(sid, 0.0)
|
||
if until and until > time.time():
|
||
return True
|
||
if until:
|
||
self._sid_blacklist.pop(sid, None)
|
||
return False
|
||
|
||
def _blacklist_sid(self, sid: str, reason: str = "") -> None:
|
||
"""Blacklist a script ID for SCRIPT_BLACKLIST_TTL seconds."""
|
||
if len(self._script_ids) <= 1:
|
||
return # Nothing to fall back to — blacklist would be pointless.
|
||
self._sid_blacklist[sid] = time.time() + self._blacklist_ttl
|
||
log.warning("Blacklisted script %s for %ds%s",
|
||
_mask_sid(sid),
|
||
int(self._blacklist_ttl),
|
||
f" ({reason})" if reason else "")
|
||
|
||
def _prune_blacklist(self, force: bool = False) -> None:
|
||
now = time.time()
|
||
for sid, until in list(self._sid_blacklist.items()):
|
||
if force or until <= now:
|
||
self._sid_blacklist.pop(sid, None)
|
||
|
||
def _pick_fanout_sids(self, key: str | None) -> list[str]:
|
||
"""Pick up to `parallel_relay` distinct non-blacklisted script IDs.
|
||
|
||
The first ID is the stable per-host choice (same as single-shot
|
||
routing); the rest are filled from the remaining pool. This keeps
|
||
session-sensitive hosts pinned to one script while still racing
|
||
extras for lower tail latency.
|
||
"""
|
||
if self._parallel_relay <= 1 or len(self._script_ids) <= 1:
|
||
return [self._script_id_for_key(key)]
|
||
primary = self._script_id_for_key(key)
|
||
picked = [primary]
|
||
others = [s for s in self._script_ids
|
||
if s != primary and not self._is_sid_blacklisted(s)]
|
||
# Round-robin-ish selection from `others`
|
||
for sid in others:
|
||
if len(picked) >= self._parallel_relay:
|
||
break
|
||
picked.append(sid)
|
||
return picked
|
||
|
||
@staticmethod
|
||
def _host_key(url_or_host: str | None) -> str:
|
||
"""Return a stable routing key for a URL or host string."""
|
||
if not url_or_host:
|
||
return ""
|
||
parsed = urlparse(url_or_host if "://" in url_or_host else f"https://{url_or_host}")
|
||
host = parsed.hostname or url_or_host
|
||
return host.lower().rstrip(".")
|
||
|
||
@classmethod
|
||
def _coalesce_key(cls, url: str, headers: dict | None) -> str:
|
||
key = [url]
|
||
if headers:
|
||
lowered = {str(k).lower(): str(v) for k, v in headers.items()}
|
||
for name in cls._COALESCE_VARY_HEADERS:
|
||
value = lowered.get(name)
|
||
if value:
|
||
key.append(f"{name}={value}")
|
||
return "\n".join(key)
|
||
|
||
@classmethod
|
||
def _retry_attempts_for_payload(cls, payload: dict) -> int:
|
||
method = str(payload.get("m", "GET")).upper()
|
||
return 2 if method in cls._SAFE_RETRY_METHODS else 1
|
||
|
||
@staticmethod
|
||
def _render_streaming_headers(resp_headers: dict, total_size: int) -> bytes:
|
||
lines = ["HTTP/1.1 200 OK"]
|
||
skip = {
|
||
"transfer-encoding",
|
||
"connection",
|
||
"keep-alive",
|
||
"content-length",
|
||
"content-range",
|
||
}
|
||
for key, value in resp_headers.items():
|
||
if key.lower() in skip:
|
||
continue
|
||
lines.append(f"{key}: {value}")
|
||
lines.append(f"Content-Length: {total_size}")
|
||
lines.append("")
|
||
lines.append("")
|
||
return "\r\n".join(lines).encode()
|
||
|
||
async def _relay_payload_h1(self, payload: dict) -> bytes:
|
||
attempts = self._retry_attempts_for_payload(payload)
|
||
async with self._semaphore:
|
||
for attempt in range(attempts):
|
||
try:
|
||
return await asyncio.wait_for(
|
||
self._relay_single(payload), timeout=self._relay_timeout,
|
||
)
|
||
except Exception as exc:
|
||
if attempt < attempts - 1:
|
||
log.debug(
|
||
"H1 relay attempt %d failed (%s: %s), retrying",
|
||
attempt + 1, type(exc).__name__, exc,
|
||
)
|
||
await self._flush_pool()
|
||
else:
|
||
raise
|
||
|
||
async def _range_probe(self, url: str, headers: dict, start_off: int,
                       end_off: int, *, max_tries: int = 3) -> bytes:
    """Fetch the byte range ``start_off``–``end_off`` of *url* via H1 relay.

    A copy of *headers* gets a ``Range`` header and is sent as a GET.
    The raw response is returned as soon as its status is below 500.
    Relay exceptions and 5xx statuses are retried up to *max_tries*
    attempts with a linearly growing pause (0.3 s × attempt number).

    Returns:
        The raw relay response — the last one received if every attempt
        ended in a 5xx status.

    Raises:
        The relay exception from the final attempt, if that attempt failed.
    """
    ranged_headers = dict(headers) if headers else {}
    ranged_headers["Range"] = f"bytes={start_off}-{end_off}"
    payload = self._build_payload("GET", url, ranged_headers, b"")
    raw = b""
    final_try = max_tries - 1
    for attempt in range(max_tries):
        pause = 0.3 * (attempt + 1)
        try:
            raw = await self._relay_payload_h1(payload)
        except Exception as exc:
            if attempt == final_try:
                raise
            log.warning(
                "Initial range probe %d-%d retry %d/%d failed: %r",
                start_off, end_off, attempt + 1, max_tries, exc,
            )
            await asyncio.sleep(pause)
            continue

        status, _, _ = split_raw_response(raw)
        # Anything below 500 (206 included) is a definitive answer.
        if status < 500:
            return raw
        if attempt < final_try:
            log.warning(
                "Initial range probe %d-%d retry %d/%d: status %d",
                start_off, end_off, attempt + 1, max_tries, status,
            )
            await asyncio.sleep(pause)
    return raw
|
||
|
||
# ── Per-host stats ────────────────────────────────────────────
|
||
|
||
def _record_site(self, url: str, bytes_: int, latency_ns: int,
|
||
errored: bool) -> None:
|
||
host = self._host_key(url)
|
||
if not host:
|
||
return
|
||
stat = self._per_site.get(host)
|
||
if stat is None:
|
||
stat = HostStat()
|
||
self._per_site[host] = stat
|
||
stat.requests += 1
|
||
stat.bytes += max(0, int(bytes_))
|
||
stat.total_latency_ns += max(0, int(latency_ns))
|
||
if errored:
|
||
stat.errors += 1
|
||
|
||
def stats_snapshot(self) -> dict:
    """Return a point-in-time snapshot of traffic + script health.

    The result has four keys: ``per_site`` (one row per host, sorted by
    bytes descending), ``blacklisted_scripts`` (masked script IDs whose
    blacklist entry has not expired yet, with seconds remaining),
    ``sni_rotation`` (a copy of the SNI host list) and ``parallel_relay``.
    """
    def _row(host, stat):
        # Mean latency in milliseconds; 0.0 when no request was counted.
        if stat.requests:
            mean_ms = stat.total_latency_ns / stat.requests / 1e6
        else:
            mean_ms = 0.0
        return {
            "host": host,
            "requests": stat.requests,
            "errors": stat.errors,
            "bytes": stat.bytes,
            "avg_ms": round(mean_ms, 1),
        }

    rows = sorted(
        (_row(host, stat) for host, stat in self._per_site.items()),
        key=lambda row: row["bytes"],
        reverse=True,
    )

    now = time.time()
    still_blocked = []
    for sid, until in self._sid_blacklist.items():
        if until > now:
            still_blocked.append({
                "sid": _mask_sid(sid),
                "expires_in_s": int(max(0, until - now)),
            })

    return {
        "per_site": rows,
        "blacklisted_scripts": still_blocked,
        "sni_rotation": list(self._sni_hosts),
        "parallel_relay": self._parallel_relay,
    }
|
||
|
||
async def _stats_logger(self):
    """Periodically log the top hosts by bytes at DEBUG level.

    Every ``STATS_LOG_INTERVAL`` seconds, and only when DEBUG logging is
    enabled and at least one host has been recorded, the first
    ``STATS_LOG_TOP_N`` rows of ``stats_snapshot()`` are logged, followed
    by the currently blacklisted scripts.  The loop exits on cancellation;
    any other error is logged and the loop continues.
    """
    interval = STATS_LOG_INTERVAL
    top_n = STATS_LOG_TOP_N
    while True:
        try:
            await asyncio.sleep(interval)
            if not log.isEnabledFor(logging.DEBUG) or not self._per_site:
                continue
            snap = self.stats_snapshot()
            top = snap["per_site"][:top_n]
            log.debug("── Per-host stats (top %d by bytes) ──", len(top))
            for row in top:
                log.debug(
                    " %-40s %5d req %2d err %8d KB avg %7.1f ms",
                    row["host"][:40], row["requests"], row["errors"],
                    row["bytes"] // 1024, row["avg_ms"],
                )
            blocked = snap["blacklisted_scripts"]
            if blocked:
                # stats_snapshot() already returns masked IDs; masking them
                # a second time would garble the display form.
                log.debug(
                    " blacklisted scripts: %s",
                    ", ".join(
                        f"{entry['sid']} ({entry['expires_in_s']}s)"
                        for entry in blocked
                    ),
                )
        except asyncio.CancelledError:
            break
        except Exception as e:
            log.debug("Stats logger error: %s", e)
|
||
|
||
def _script_id_for_key(self, key: str | None = None) -> str:
|
||
"""Pick a stable Apps Script ID for a host or fallback to round-robin.
|
||
|
||
When multiple deployments are configured, using a stable mapping per
|
||
host reduces IP/session churn for sites that are sensitive to endpoint
|
||
changes. If no key is available, we keep the older round-robin fallback
|
||
so warmup/keepalive traffic still distributes normally.
|
||
|
||
Blacklisted IDs are skipped by probing forward in the list until a
|
||
healthy one is found; if none, the stable pick is returned anyway.
|
||
"""
|
||
if len(self._script_ids) == 1:
|
||
return self._script_ids[0]
|
||
if not key:
|
||
return self._next_script_id()
|
||
digest = hashlib.sha1(key.encode("utf-8")).digest()
|
||
base = int.from_bytes(digest[:4], "big") % len(self._script_ids)
|
||
n = len(self._script_ids)
|
||
for offset in range(n):
|
||
sid = self._script_ids[(base + offset) % n]
|
||
if not self._is_sid_blacklisted(sid):
|
||
return sid
|
||
return self._script_ids[base]
|
||
|
||
def _exec_path(self, url_or_host: str | None = None) -> str:
|
||
"""Get the Apps Script endpoint path (/dev or /exec)."""
|
||
sid = self._script_id_for_key(self._host_key(url_or_host))
|
||
return self._exec_path_for_sid(sid)
|
||
|
||
def _exec_path_for_sid(self, sid: str) -> str:
|
||
"""Build the /macros/s/<sid>/(dev|exec) path for a specific script ID."""
|
||
endpoint = "dev" if self._dev_available else "exec"
|
||
# Force Google Apps Script UI/errors to English for stable diagnostics.
|
||
return f"/macros/s/{sid}/{endpoint}?hl={self._apps_script_lang}"
|
||
|
||
def _apps_script_headers(self) -> dict[str, str]:
|
||
"""Headers for Apps Script relay calls (control-plane, not target origin)."""
|
||
return {
|
||
"content-type": "application/json",
|
||
"accept": "application/json,text/plain,*/*",
|
||
"accept-language": "en-US,en;q=0.9",
|
||
}
|
||
async def _flush_pool(self):
|
||
"""Close all pooled connections (they may be stale after errors)."""
|
||
async with self._pool_lock:
|
||
for _, writer, _ in self._pool:
|
||
try:
|
||
writer.close()
|
||
except Exception:
|
||
pass
|
||
self._pool.clear()
|
||
|
||
async def _refill_pool(self):
|
||
"""Background: open connections in parallel to refill empty pool."""
|
||
try:
|
||
coros = [self._add_conn_to_pool() for _ in range(8)]
|
||
await asyncio.gather(*coros, return_exceptions=True)
|
||
finally:
|
||
self._refilling = False
|
||
|
||
async def _add_conn_to_pool(self):
|
||
"""Open one TLS connection and add it to the pool."""
|
||
try:
|
||
r, w = await asyncio.wait_for(self._open(), timeout=5)
|
||
t = asyncio.get_running_loop().time()
|
||
async with self._pool_lock:
|
||
if len(self._pool) < self._pool_max:
|
||
self._pool.append((r, w, t))
|
||
else:
|
||
try:
|
||
w.close()
|
||
except Exception:
|
||
pass
|
||
except Exception:
|
||
pass
|
||
|
||
async def _pool_maintenance(self):
    """Background loop that prunes and tops up the connection pool.

    Every three seconds: connections older than ``_conn_ttl`` or whose
    reader is at EOF are closed and dropped, then up to five new
    connections are opened when fewer than ``_pool_min_idle`` remain.
    The loop ends on cancellation; any other error is ignored and the
    loop continues with the next tick.
    """
    while True:
        try:
            await asyncio.sleep(3)
            now = asyncio.get_running_loop().time()

            # Drop expired or dead connections.
            async with self._pool_lock:
                survivors = []
                for reader, writer, opened_at in self._pool:
                    fresh = (now - opened_at) < self._conn_ttl
                    if fresh and not reader.at_eof():
                        survivors.append((reader, writer, opened_at))
                        continue
                    try:
                        writer.close()
                    except Exception:
                        pass
                self._pool = survivors
                idle = len(self._pool)

            # Top up when below the minimum idle level (max five per tick).
            shortfall = max(0, self._pool_min_idle - idle)
            if shortfall > 0:
                await asyncio.gather(
                    *(self._add_conn_to_pool()
                      for _ in range(min(shortfall, 5))),
                    return_exceptions=True,
                )

        except asyncio.CancelledError:
            break
        except Exception:
            pass
|
||
|
||
async def _warm_pool(self):
|
||
"""Pre-open TLS connections in the background. Never blocks relay()."""
|
||
if self._warmed:
|
||
return
|
||
self._warmed = True
|
||
if self._sni_probe_task is None and len(self._sni_hosts) > 1:
|
||
self._sni_probe_task = self._spawn(self._rank_sni_hosts())
|
||
self._warm_task = self._spawn(self._do_warm())
|
||
# Start continuous pool maintenance
|
||
if self._maintenance_task is None:
|
||
self._maintenance_task = self._spawn(self._pool_maintenance())
|
||
# Periodic per-host stats logger (opt-in via log level)
|
||
if self._stats_task is None:
|
||
self._stats_task = self._spawn(self._stats_logger())
|
||
if self._execution_task is None:
|
||
self._execution_task = self._spawn(self._execution_logger())
|
||
# Start H2 connection (runs alongside H1 pool)
|
||
if self._h2:
|
||
self._spawn(self._h2_connect_and_warm())
|
||
# H1 container keepalive — runs unconditionally so the Apps Script
|
||
# container never goes cold even when H2 is unavailable. When H2 IS
|
||
# active its _keepalive_loop skips the ping; they do not double-fire.
|
||
self._spawn(self._h1_container_keepalive())
|
||
|
||
async def wait_until_warm(self, timeout: float | None = None) -> bool:
|
||
"""Start warmup and wait until the initial pool-open phase finishes.
|
||
|
||
Returns True if warmup finished before timeout, else False.
|
||
"""
|
||
await self._warm_pool()
|
||
if self._pool_ready.is_set():
|
||
return True
|
||
try:
|
||
if timeout is None or timeout <= 0:
|
||
await self._pool_ready.wait()
|
||
else:
|
||
await asyncio.wait_for(self._pool_ready.wait(), timeout=timeout)
|
||
return True
|
||
except asyncio.TimeoutError:
|
||
return False
|
||
|
||
def _spawn(self, coro) -> asyncio.Task:
|
||
"""Create a task and keep a strong reference for clean cancellation."""
|
||
task = asyncio.create_task(coro)
|
||
self._bg_tasks.add(task)
|
||
task.add_done_callback(self._bg_tasks.discard)
|
||
return task
|
||
|
||
async def close(self):
    """Cancel background tasks and close all pooled / H2 connections.

    Every tracked background task is cancelled and awaited, the task
    handles are reset to ``None``, the H1 pool is flushed and each H2
    transport is closed.  Errors while closing an H2 transport are logged
    at DEBUG level and do not stop the shutdown.
    """
    pending = list(self._bg_tasks)
    for background in pending:
        background.cancel()
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)
    self._bg_tasks.clear()

    for slot in ("_warm_task", "_maintenance_task", "_stats_task",
                 "_execution_task", "_keepalive_task"):
        setattr(self, slot, None)

    await self._flush_pool()

    for transport in self._h2_pool:
        try:
            await transport.close()
        except Exception as exc:
            log.debug("h2 pool close: %s", exc)
|
||
|
||
async def _h2_connect(self):
    """Connect all HTTP/2 transports in the pool.

    Does nothing when the H2 pool is empty or H2 is still inside its
    disabled window.  If at least one transport connects, an H2 success is
    recorded; if none do, or anything else raises, an H2 failure is
    recorded and a warning notes the fallback to the H1 pool.
    """
    if not self._h2_pool:
        return
    if time.time() < self._h2_disabled_until:
        return

    def _report_failure(err):
        # Shared by the "no transport connected" and "unexpected error" paths.
        self._record_h2_failure(err)
        log.warning(
            "H2 connect failed (%s: %s), using H1 pool fallback",
            type(err).__name__, err or "(no details)",
        )

    try:
        await self._ensure_sni_ranked()
        outcomes = await asyncio.gather(
            *[transport.ensure_connected() for transport in self._h2_pool],
            return_exceptions=True,
        )
        live = sum(1 for outcome in outcomes
                   if not isinstance(outcome, Exception))
        if live > 0:
            self._record_h2_success()
            log.info(
                "H2 multiplexing active — %d/%d connections live",
                live, len(self._h2_pool),
            )
        else:
            _report_failure(
                next(outcome for outcome in outcomes
                     if isinstance(outcome, Exception))
            )
    except Exception as e:
        _report_failure(e)
|
||
|
||
async def _h2_connect_and_warm(self):
|
||
"""Connect H2, pre-warm the Apps Script container, start keepalive."""
|
||
await self._h2_connect()
|
||
if self._h2_available():
|
||
self._spawn(self._prewarm_script())
|
||
# Always start keepalive — even on startup failure it will retry H2
|
||
# once the cooldown expires instead of leaving H1-only permanently.
|
||
if self._keepalive_task is None or self._keepalive_task.done():
|
||
self._keepalive_task = self._spawn(self._keepalive_loop())
|
||
|
||
async def _prewarm_script(self):
    """Pre-warm Apps Script over H2 and detect whether ``/dev`` is usable.

    A small relay request for ``http://example.com/`` is POSTed to the
    first script's ``/dev`` endpoint.  If the decoded reply contains an
    ``"s"`` key, ``_dev_available`` is switched on and the method returns.
    Otherwise, or if that attempt fails, the same request is sent to
    ``/exec`` purely to warm the container.  Failures are only logged at
    DEBUG level.
    """
    warm_body = json.dumps(
        {"m": "GET", "u": "http://example.com/", "k": self.auth_key}
    ).encode()
    hdrs = self._apps_script_headers()
    sid = self._script_ids[0]

    async def _timed_post(path):
        """POST the warm-up payload; return (response, elapsed milliseconds)."""
        started = time.perf_counter()
        self._record_execution(sid)
        reply = await asyncio.wait_for(
            self._h2.request(
                method="POST", path=path, host=self.http_host,
                headers=hdrs, body=warm_body,
            ),
            timeout=15,
        )
        return reply, (time.perf_counter() - started) * 1000

    # Probe /dev first; per the original author it answers inline without
    # a 302 redirect, saving a round trip per request when it works.
    try:
        dev_path = f"/macros/s/{sid}/dev?hl={self._apps_script_lang}"
        (status, _, body), dt = await _timed_post(dev_path)
        data = load_relay_json(body.decode(errors="replace"))
        if "s" in data:
            self._dev_available = True
            log.info("/dev fast path active (%.0fms, no redirect)", dt)
            return
    except Exception as e:
        log.debug("/dev test failed: %s", e)

    # Fallback: warm up with /exec.
    try:
        exec_path = f"/macros/s/{sid}/exec?hl={self._apps_script_lang}"
        _reply, dt = await _timed_post(exec_path)
        log.info("Apps Script pre-warmed in %.0fms", dt)
    except Exception as e:
        log.debug("Pre-warm failed: %s", e)
|
||
|
||
async def _keepalive_loop(self):
    """Keep the H2 transports and the Apps Script container warm.

    Once per minute, unless H2 is absent or inside its disabled window:
    disconnected transports are reconnected, every connected transport is
    sent an H2 PING, and one small relay request is POSTed to Apps Script.
    The loop ends on cancellation; other errors are logged at DEBUG level.
    """
    while True:
        try:
            # 60s cadence — the original author notes that DPI/NAT may
            # reset connections that stay idle for roughly 30-60s.
            await asyncio.sleep(60)

            # H2 absent or still cooling down: skip this tick.
            if self._h2 is None or time.time() < self._h2_disabled_until:
                continue

            # Bring back any transport that dropped.
            for transport in list(self._h2_pool):
                if transport.is_connected:
                    continue
                try:
                    await asyncio.wait_for(
                        transport.reconnect(),
                        timeout=max(self._tls_connect_timeout, 8.0),
                    )
                    self._record_h2_success()
                    log.info("H2 connection re-established")
                except Exception as exc:
                    # Background recovery attempt: deliberately NOT
                    # recorded as an H2 failure.
                    log.debug("H2 background reconnect failed: %s", exc)

            if not any(t.is_connected for t in self._h2_pool):
                continue  # all transports down — skip ping

            # PING frame on every live transport so the connection does
            # not look idle.
            for transport in self._h2_pool:
                if not transport.is_connected:
                    continue
                try:
                    await transport.ping()
                except Exception:
                    pass

            # Apps Script keepalive — warm the container.
            payload = {"m": "GET", "u": "http://example.com/", "k": self.auth_key}
            sid = self._script_id_for_key(self._host_key("example.com"))
            path = self._exec_path_for_sid(sid)
            started = time.perf_counter()
            self._record_execution(sid)
            await asyncio.wait_for(
                self._h2.request(
                    method="POST", path=path, host=self.http_host,
                    headers=self._apps_script_headers(),
                    body=json.dumps(payload).encode(),
                ),
                timeout=20,
            )
            elapsed_ms = (time.perf_counter() - started) * 1000
            log.debug("Keepalive ping: %.0fms", elapsed_ms)
        except asyncio.CancelledError:
            break
        except Exception as e:
            log.debug("Keepalive failed: %s", e)
|
||
|
||
async def _h1_container_keepalive(self):
    """Keep the Apps Script container warm over H1 while H2 is unavailable.

    Every 240 seconds a small GET for ``http://example.com/`` is relayed
    through ``_relay_payload_h1``.  The tick is skipped whenever
    ``_h2_available()`` is true, because the H2 keepalive loop is pinging
    in that case.  The loop ends on cancellation; other errors are logged
    at DEBUG level.
    """
    while True:
        try:
            await asyncio.sleep(240)
            if self._h2_available():
                continue  # H2 keepalive is already pinging, skip

            warm_payload = self._build_payload(
                "GET", "http://example.com/", {}, b""
            )
            started = time.perf_counter()
            # No outer wait_for: per the original author,
            # _relay_payload_h1 enforces its own per-attempt timeout, and a
            # shorter outer one would cancel valid in-progress relays.
            await self._relay_payload_h1(warm_payload)
            elapsed_ms = (time.perf_counter() - started) * 1000
            log.debug("H1 container keepalive: %.0fms", elapsed_ms)
        except asyncio.CancelledError:
            break
        except Exception as exc:
            log.debug("H1 container keepalive failed: %s", exc)
|
||
|
||
async def _do_warm(self):
    """Open ``_warm_count`` pool connections in parallel, then signal readiness.

    Individual failures are fine.  Because ``_add_conn_to_pool`` swallows
    its own errors, the number of opened connections is taken from how
    much the pool actually grew rather than from the gather results.  The
    figure is best effort: connections taken from or added to the pool by
    other tasks during warm-up shift it.

    ``_pool_ready`` is set even if warm-up raises or is cancelled, so
    ``relay()`` never keeps waiting for a warm-up that is already over.
    """
    try:
        await self._ensure_sni_ranked()
        count = self._warm_count
        pool_before = len(self._pool)
        results = await asyncio.gather(
            *(self._add_conn_to_pool() for _ in range(count)),
            return_exceptions=True,
        )
        failed = sum(1 for r in results if isinstance(r, BaseException))
        grown = len(self._pool) - pool_before
        opened = max(0, min(count - failed, grown))
        log.info("Pre-warmed %d/%d TLS connections", opened, count)
    finally:
        # Signal that the pool-open phase finished so relay() can stop
        # waiting on the first request.
        self._pool_ready.set()
|
||
|
||
async def _reconnect_pool_members(self) -> None:
    """Reconnect, one after another, every H2 transport that is down.

    Each reconnect is limited to ``max(_tls_connect_timeout, 8.0)``
    seconds.  The outcome is only logged at DEBUG level; no H2
    success or failure is recorded here.
    """
    for transport in self._h2_pool:
        if transport.is_connected:
            continue
        try:
            await asyncio.wait_for(
                transport.reconnect(),
                timeout=max(self._tls_connect_timeout, 8.0),
            )
            log.debug("H2 pool member recovered")
        except Exception as exc:
            log.debug("H2 pool member reconnect failed: %s", exc)
|
||
|
||
def _auth_header(self) -> str:
|
||
return f"X-Auth-Key: {self.auth_key}\r\n" if self.auth_key else ""
|
||
|
||
# ── Exit node relay ───────────────────────────────────────────
|
||
|
||
@staticmethod
|
||
def _normalize_exit_node_provider(raw: object) -> str:
|
||
provider = str(raw or "custom").strip().lower()
|
||
aliases = {
|
||
"cloudflare_worker": "cloudflare",
|
||
"worker": "cloudflare",
|
||
"cf": "cloudflare",
|
||
"deno_deploy": "deno",
|
||
"self_hosted": "vps",
|
||
"self-hosted": "vps",
|
||
"selfhosted": "vps",
|
||
"server": "vps",
|
||
}
|
||
return aliases.get(provider, provider or "custom")
|
||
|
||
@classmethod
|
||
def _resolve_exit_node_url(cls, provider: str,
|
||
en_cfg: dict[str, object]) -> str:
|
||
providers = en_cfg.get("providers")
|
||
if not isinstance(providers, dict):
|
||
providers = {}
|
||
|
||
def _pick_from(mapping: dict[str, object], *keys: str) -> str:
|
||
for key in keys:
|
||
value = mapping.get(key)
|
||
if isinstance(value, str):
|
||
value = value.strip()
|
||
if value:
|
||
return value.rstrip("/")
|
||
return ""
|
||
|
||
# Beginner-first: one URL field is enough for all providers.
|
||
direct = _pick_from(en_cfg, "url")
|
||
if direct:
|
||
return direct
|
||
|
||
if provider == "cloudflare":
|
||
selected = _pick_from(
|
||
en_cfg, "cloudflare_url", "worker_url", "cf_url",
|
||
) or _pick_from(
|
||
providers, "cloudflare", "cloudflare_worker", "worker", "cf",
|
||
)
|
||
elif provider == "deno":
|
||
selected = _pick_from(en_cfg, "deno_url") or _pick_from(
|
||
providers, "deno", "deno_deploy",
|
||
)
|
||
elif provider == "vps":
|
||
selected = _pick_from(
|
||
en_cfg, "vps_url", "server_url", "self_hosted_url",
|
||
) or _pick_from(
|
||
providers, "vps", "self_hosted", "server",
|
||
)
|
||
else:
|
||
selected = ""
|
||
|
||
if selected:
|
||
return selected
|
||
# Backward compatibility for older config format.
|
||
return _pick_from(en_cfg, "relay_url")
|
||
|
||
def _exit_node_matches(self, url: str) -> bool:
|
||
"""Return True if this URL should be routed through the exit node."""
|
||
if not self._exit_node_enabled or not self._exit_node_url:
|
||
return False
|
||
if self._exit_node_mode == "full":
|
||
return True
|
||
# selective: check if destination hostname matches configured list
|
||
host = self._host_key(url)
|
||
if not host:
|
||
return False
|
||
for pattern in self._exit_node_hosts:
|
||
if host == pattern or host.endswith("." + pattern):
|
||
return True
|
||
return False
|
||
|
||
async def _relay_via_exit_node(self, payload: dict) -> bytes:
    """Chain: Apps Script → edge relay (exit node) → Destination.

    Traffic path as described by the original author:
        Client → [domain fronting TLS] → Apps Script (Google)
               → [UrlFetchApp.fetch] → exit node (non-Google IP)
               → [fetch()] → Destination

    The client still only talks to Apps Script, while the destination
    sees the exit node's address instead of a Google one.

    *payload* is copied, re-keyed with the exit-node PSK, serialised to
    JSON and passed to ``_build_payload`` as the body of a POST to the
    exit-node URL.  That outer request goes through the batch collector.
    The body of the Apps Script response is then parsed as the exit
    node's own relay response and returned.
    """
    # Inner payload: what the exit node will execute.
    inner = {**payload, "k": self._exit_node_psk}
    inner_headers = inner.get("h")
    if isinstance(inner_headers, dict):
        # Drop accept-encoding so the target answers uncompressed; per the
        # original author the exit node's fetch does not decompress, so a
        # compressed body would be forwarded as garbled bytes.
        inner["h"] = {
            name: value for name, value in inner_headers.items()
            if name.lower() != "accept-encoding"
        }
    inner_json = json.dumps(inner).encode()

    # Outer payload: the POST Apps Script makes to the exit node.
    outer = self._build_payload(
        "POST",
        self._exit_node_url,
        {"Content-Type": "application/json"},
        inner_json,
    )
    # Set the content type explicitly on the outer relay request.
    outer["ct"] = "application/json"

    log.debug(
        "Exit node chain: Apps Script → %s → %s",
        self._exit_node_url.split("//", 1)[-1][:50],
        payload.get("u", "")[:60],
    )

    # Go through the batch collector so exit-node requests are coalesced
    # with other concurrent requests.
    raw = await self._batch_submit(outer)

    _status, _headers, exit_node_body = split_raw_response(raw)
    response = parse_relay_response(
        exit_node_body, self._max_response_body_bytes,
    )
    log.debug("Exit node relay OK: %s", payload.get("u", "")[:80])
    return response
|
||
|
||
# ── Apps Script relay (apps_script mode) ──────────────────────
|
||
|
||
async def relay(self, method: str, url: str,
                headers: dict, body: bytes = b"") -> bytes:
    """Relay an HTTP request through Apps Script.

    Steps:
      1. Start pool warm-up on first use and wait at most
         ``_tls_connect_timeout`` seconds for the pool to become ready.
      2. For POSTs to a ``/videoplayback`` URL, strip SABR quality-track
         fields from the body.
      3. If the URL matches the exit-node rules, try the exit-node chain
         first; on failure fall back to the normal path.
      4. Otherwise route by request type: stateful requests go through
         ``_relay_with_retry``, body-less GETs without a Range header are
         coalesced, everything else is batched.

    Each attempt is recorded in the per-host stats with its latency,
    response size and error flag.  A successful exit-node relay is
    recorded with its real size too.

    Returns a raw HTTP response (status + headers + body).  Exceptions
    from the normal relay path propagate to the caller.
    """
    if not self._warmed:
        await self._warm_pool()

    # Give the first request(s) a bounded chance to find a warm pool.  The
    # wait is capped so a slow network never blocks the caller
    # indefinitely; the normal retry/fallback path takes over afterwards.
    if not self._pool_ready.is_set():
        try:
            await asyncio.wait_for(
                asyncio.shield(self._pool_ready.wait()),
                timeout=self._tls_connect_timeout,
            )
        except asyncio.TimeoutError:
            log.debug("Pool warm timeout — proceeding with cold pool")

    # SABR / videoplayback: per the original author, quality-track
    # selection fields make the origin bundle several tracks into one
    # response, which can exceed what Apps Script can buffer
    # ("Response too large" → 502).  Stripping them forces a
    # single-track response.
    if method == "POST" and body and "/videoplayback" in url:
        stripped = self._strip_sabr_quality_tracks(body)
        if stripped != body:
            log.debug(
                "SABR strip: removed %d quality-track bytes from %s",
                len(body) - len(stripped),
                url.split("?")[0][-60:],
            )
            body = stripped

    payload = self._build_payload(method, url, headers, body)

    # Exit node short-circuit: route to a non-Google IP via Apps Script.
    if self._exit_node_matches(url):
        t0 = time.perf_counter()
        errored = False
        exit_result: bytes = b""
        try:
            exit_result = await asyncio.wait_for(
                self._relay_via_exit_node(payload),
                timeout=self._relay_timeout + self._tls_connect_timeout,
            )
            return exit_result
        except Exception as exc:
            errored = True
            log.warning(
                "Exit node failed for %s (%s: %s), falling back to Apps Script",
                url[:60], type(exc).__name__, exc,
            )
        finally:
            latency_ns = int((time.perf_counter() - t0) * 1e9)
            # Record the real response size (0 when the attempt failed).
            self._record_site(url, len(exit_result), latency_ns, errored)
        # Fall through to the normal Apps Script relay on failure.

    t0 = time.perf_counter()
    errored = False
    result: bytes = b""
    try:
        # Stateful/browser-navigation requests should preserve exact
        # ordering and header context; batching/coalescing is reserved
        # for the remaining requests.
        if self._is_stateful_request(method, url, headers, body):
            result = await self._relay_with_retry(payload)
            return result

        # Coalesce concurrent GETs for the same URL — but never when a
        # Range header is present: parallel range downloads must each
        # hit the server independently.
        has_range = bool(self._header_value(headers, "range"))
        if method == "GET" and not body and not has_range:
            result = await self._coalesced_submit(
                self._coalesce_key(url, headers), payload,
            )
            return result

        result = await self._batch_submit(payload)
        return result
    except Exception:
        errored = True
        raise
    finally:
        latency_ns = int((time.perf_counter() - t0) * 1e9)
        self._record_site(url, len(result), latency_ns, errored)
|
||
|
||
async def _coalesced_submit(self, key: str, payload: dict) -> bytes:
|
||
"""Dedup concurrent requests for the same URL (no Range header).
|
||
|
||
Uses `_batch_lock` to atomically check-and-append, preventing a
|
||
race where the owning task's `finally` pops the entry between
|
||
the check and append by a second task.
|
||
"""
|
||
loop = asyncio.get_running_loop()
|
||
async with self._batch_lock:
|
||
waiters = self._coalesce.get(key)
|
||
if waiters is not None:
|
||
future = loop.create_future()
|
||
waiters.append(future)
|
||
log.debug("Coalesced request: %s", key.split("\n", 1)[0][:60])
|
||
waiting = True
|
||
else:
|
||
self._coalesce[key] = []
|
||
waiting = False
|
||
|
||
if waiting:
|
||
return await future
|
||
|
||
try:
|
||
result = await self._batch_submit(payload)
|
||
except Exception as e:
|
||
async with self._batch_lock:
|
||
waiters = self._coalesce.pop(key, [])
|
||
for f in waiters:
|
||
if not f.done():
|
||
f.set_exception(e)
|
||
raise
|
||
|
||
async with self._batch_lock:
|
||
waiters = self._coalesce.pop(key, [])
|
||
for f in waiters:
|
||
if not f.done():
|
||
f.set_result(result)
|
||
return result
|
||
|
||
async def relay_parallel(self, method: str, url: str,
                         headers: dict, body: bytes = b"",
                         chunk_size: int = 512 * 1024,
                         max_parallel: int = 8,
                         max_chunks: int = 256,
                         min_size: int = 0) -> bytes:
    """Relay with parallel range acceleration for large downloads.

    Strategy:
      1. Probe with ``Range: bytes=0-<chunk_size-1>``.
      2. If the target does not answer 206, return that response as-is.
      3. If the probe already holds the whole entity, return it as a 200.
      4. Otherwise fetch the remaining ranges concurrently (at most
         *max_parallel* in flight, each through ``_relay_payload_h1``
         with up to three attempts) and return the joined body as a 200.

    Args:
        method, url, headers, body: the client request.  Anything other
            than a body-less GET is delegated to ``relay()``.
        chunk_size: bytes per range request; raised automatically when
            needed to keep the request count at or below *max_chunks*.
        max_parallel: maximum number of range requests in flight.
        max_chunks: upper bound on range requests (0 disables tuning).
        min_size: entities smaller than this are not fetched in parallel.

    Returns:
        A raw HTTP response.  A 502 error response is returned when the
        entity exceeds ``_max_response_body_bytes`` or a chunk finally
        fails.
    """
    if method != "GET" or body:
        return await self.relay(method, url, headers, body)

    # Probe: first chunk with a Range header.
    first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
    status, resp_hdrs, resp_body = split_raw_response(first_resp)

    # No range support → return the single response unchanged; the
    # client sent a plain GET, so that is what it expects.
    if status != 206:
        return first_resp

    # Total size comes from Content-Range, e.g. "bytes 0-262143/1048576".
    parsed_range = parse_content_range(resp_hdrs.get("content-range", ""))
    if not parsed_range:
        # Can't parse — downgrade to 200 so the client (which sent a
        # plain GET) is not confused by 206 + Content-Range.
        return self._rewrite_206_to_200(first_resp)
    first_start, first_end, total_size = parsed_range
    first_err = validate_range_response(
        status, resp_hdrs, resp_body, first_start, first_end, total_size,
    )
    if first_start != 0 or first_err:
        return self._rewrite_206_to_200(first_resp)
    if total_size > self._max_response_body_bytes:
        return error_response(
            502,
            "Relay response exceeds cap "
            f"({self._max_response_body_bytes} bytes). "
            "Increase MAX_RESPONSE_BODY_BYTES in src/core/constants.py if your system has enough RAM.",
        )

    # Probe already fetched everything.  MUST rewrite to 200: the client
    # never sent a Range header and a stray 206 breaks fetch()/XHR.
    # Decided on the bytes actually received, never on chunk_size, so a
    # partial probe is not mistaken for the whole file.
    if len(resp_body) >= total_size:
        return self._rewrite_206_to_200(first_resp)

    if min_size > 0 and total_size < min_size:
        # Too small for parallel fetching, but the probe only holds a
        # prefix — fetch the entity normally instead of returning a
        # truncated body as a 200.
        return await self.relay(method, url, headers, body)

    if max_chunks > 0:
        required_chunk_size = max(
            chunk_size,
            (total_size + max_chunks - 1) // max_chunks,
        )
        if required_chunk_size != chunk_size:
            log.info(
                "Parallel download tuning: chunk size raised from %d KB to %d KB "
                "to keep request count under %d",
                chunk_size // 1024,
                required_chunk_size // 1024,
                max_chunks,
            )
            chunk_size = required_chunk_size

    # Remaining ranges start right after the bytes the probe delivered.
    ranges = [
        (start, min(start + chunk_size - 1, total_size - 1))
        for start in range(len(resp_body), total_size, chunk_size)
    ]

    log.info("Parallel download: %d bytes, %d chunks of %d KB",
             total_size, len(ranges) + 1, chunk_size // 1024)

    sem = asyncio.Semaphore(max_parallel)
    progress_lock = asyncio.Lock()
    completed_chunks = 1  # the range probe already succeeded
    completed_bytes = len(resp_body)
    total_chunks = len(ranges) + 1
    total_bytes = total_size
    # One clock (perf_counter) for start, progress and completion, so
    # elapsed times are never computed across two different clocks.
    started = time.perf_counter()
    last_progress_log = started

    async def fetch_range(s, e, max_tries: int = 3):
        """Fetch bytes s..e, retrying with a linear back-off."""
        nonlocal completed_chunks, completed_bytes, last_progress_log
        async with sem:
            range_headers = dict(headers) if headers else {}
            range_headers["Range"] = f"bytes={s}-{e}"
            payload = self._build_payload("GET", url, range_headers, b"")
            last_err = None
            for attempt in range(max_tries):
                try:
                    raw = await self._relay_payload_h1(payload)
                    chunk_status, chunk_headers, chunk_body = split_raw_response(raw)
                    err = validate_range_response(
                        chunk_status, chunk_headers, chunk_body,
                        s, e, total_size,
                    )
                except Exception as exc:
                    err = repr(exc)
                else:
                    if err is None:
                        now = time.perf_counter()
                        async with progress_lock:
                            completed_chunks += 1
                            completed_bytes += len(chunk_body)
                            should_log = (
                                completed_chunks == total_chunks
                                or (now - last_progress_log) >= 5.0
                            )
                            if should_log:
                                elapsed_so_far = max(0.001, now - started)
                                log.info(
                                    "Parallel download progress: %s [%d/%d chunks]",
                                    progress_line(
                                        elapsed=elapsed_so_far,
                                        done=completed_bytes,
                                        total=total_bytes,
                                        speed_bytes_per_sec=completed_bytes / elapsed_so_far,
                                    ),
                                    completed_chunks, total_chunks,
                                )
                                last_progress_log = now
                        return chunk_body
                last_err = err
                log.warning("Range %d-%d retry %d/%d: %s",
                            s, e, attempt + 1, max_tries, last_err)
                # No point in pausing after the final attempt.
                if attempt < max_tries - 1:
                    await asyncio.sleep(0.3 * (attempt + 1))
            raise RuntimeError(
                f"chunk {s}-{e} failed after {max_tries} tries: {last_err}"
            )

    results = await asyncio.gather(
        *[fetch_range(s, e) for s, e in ranges],
        return_exceptions=True,
    )
    elapsed = time.perf_counter() - started

    # Assemble the full body; any failed chunk aborts with a 502.
    parts = [resp_body]
    for i, r in enumerate(results):
        # BaseException also covers a cancelled chunk task, which must
        # not be joined into the body.
        if isinstance(r, BaseException):
            log.error("Range chunk %d failed: %s", i, r)
            return error_response(502, f"Parallel download failed: {r}")
        parts.append(r)

    full_body = b"".join(parts)
    speed_bps = len(full_body) / elapsed if elapsed > 0 else 0
    log.info(
        "Parallel download complete: %s",
        progress_line(
            elapsed=elapsed,
            done=len(full_body),
            total=len(full_body),
            speed_bytes_per_sec=speed_bps,
        ),
    )

    # Return as 200 OK (the client sent a normal GET).  Framing and range
    # headers from the probe are dropped and Content-Length is recomputed.
    skip = {"transfer-encoding", "connection", "keep-alive",
            "content-length", "content-encoding", "content-range"}
    head_lines = ["HTTP/1.1 200 OK"]
    head_lines.extend(
        f"{k}: {v}" for k, v in resp_hdrs.items() if k.lower() not in skip
    )
    head_lines.append(f"Content-Length: {len(full_body)}")
    head = "\r\n".join(head_lines) + "\r\n\r\n"
    return head.encode() + full_body
|
||
|
||
async def stream_parallel_download(self, url: str, headers: dict,
                                   writer,
                                   *,
                                   chunk_size: int = 512 * 1024,
                                   max_parallel: int = 8,
                                   max_chunks: int = 256,
                                   min_size: int = 0) -> bool:
    """Stream a large range-capable download to the client incrementally.

    A first ranged probe establishes the total size. The remaining byte
    ranges are fetched concurrently (at most ``max_parallel`` at a time),
    spooled into a temporary file, and written to ``writer`` strictly in
    order as they become available.

    Args:
        url: Target URL to download.
        headers: Client request headers; a ``Range`` header is added per chunk.
        writer: Client stream; only ``write``, ``drain``, ``is_closing`` and
            ``close`` are used.
        chunk_size: Requested size of each range in bytes.
        max_parallel: Maximum number of concurrent range fetches.
        max_chunks: Upper bound on the number of ranges; ``chunk_size`` is
            raised when needed to respect it. ``0`` disables the bound.
        min_size: Files smaller than this fall back to the normal relay path.

    Returns:
        False when the target should fall back to the normal relay path
        (for example no range support or the file is too small).
        True once this method has taken ownership of the client response,
        even if the stream later aborts.
    """
    first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)

    status, resp_hdrs, resp_body = split_raw_response(first_resp)
    if status != 206:
        log.info(
            "Streaming download fallback: initial probe returned %s for %s",
            status, url[:80],
        )
        return False

    parsed_range = parse_content_range(resp_hdrs.get("content-range", ""))
    if not parsed_range:
        log.info(
            "Streaming download fallback: missing/invalid Content-Range for %s",
            url[:80],
        )
        return False
    first_start, first_end, total_size = parsed_range
    first_err = validate_range_response(
        status, resp_hdrs, resp_body, first_start, first_end, total_size,
    )
    if first_start != 0 or first_err:
        log.info(
            "Streaming download fallback: invalid first range (%s) for %s",
            first_err or f"start={first_start}",
            url[:80],
        )
        return False
    if min_size > 0 and total_size < min_size:
        log.info(
            "Streaming download fallback: file too small (%d < %d) for %s",
            total_size, min_size, url[:80],
        )
        return False
    if max_chunks > 0:
        required_chunk_size = max(
            chunk_size,
            (total_size + max_chunks - 1) // max_chunks,
        )
        if required_chunk_size != chunk_size:
            log.info(
                "Parallel download tuning: chunk size raised from %d KB to %d KB "
                "to keep request count under %d",
                chunk_size // 1024,
                required_chunk_size // 1024,
                max_chunks,
            )
            chunk_size = required_chunk_size

    # Fast path: the probe already returned the whole file. Only the body
    # length is trusted here; a server that answered with a shorter range
    # than requested must still have its remaining bytes fetched below,
    # otherwise the client would receive a truncated body.
    if len(resp_body) >= total_size:
        writer.write(self._render_streaming_headers(resp_hdrs, total_size))
        writer.write(resp_body)
        await writer.drain()
        return True

    # Remaining ranges start right after the bytes the probe delivered.
    ranges = []
    start = len(resp_body)
    while start < total_size:
        end = min(start + chunk_size - 1, total_size - 1)
        ranges.append((start, end))
        start = end + 1

    log.info("Parallel streaming download: %d bytes, %d chunks of %d KB",
             total_size, len(ranges) + 1, chunk_size // 1024)

    temp_file = tempfile.TemporaryFile(prefix="mhrvpn_dl_")
    file_lock = asyncio.Lock()           # serialises seek+read/write on temp_file
    sem = asyncio.Semaphore(max_parallel)
    cancel_event = asyncio.Event()       # tells workers to stop retrying
    tasks: list[asyncio.Task] = []
    ready = [asyncio.Event() for _ in ranges]            # set on success or final failure
    errors: list[Exception | None] = [None for _ in ranges]
    delivered_chunks = 1                 # the probe counts as the first chunk
    delivered_bytes = len(resp_body)
    total_chunks = len(ranges) + 1
    last_progress_log = time.perf_counter()
    t0 = time.perf_counter()

    async def _write_progress(force: bool = False) -> None:
        """Log delivery progress, at most every 5 seconds unless forced."""
        nonlocal last_progress_log
        now = time.perf_counter()
        if not force and (now - last_progress_log) < 5.0:
            return
        elapsed = max(0.001, now - t0)
        speed_bps = delivered_bytes / elapsed
        log.info(
            "Parallel download progress: %s [%d/%d chunks]",
            progress_line(
                elapsed=elapsed,
                done=delivered_bytes,
                total=total_size,
                speed_bytes_per_sec=speed_bps,
            ),
            delivered_chunks, total_chunks,
        )
        last_progress_log = now

    async def fetch_range(index: int, start_off: int, end_off: int,
                          max_tries: int = 3) -> None:
        """Fetch one byte range into the spool file, retrying on failure."""
        async with sem:
            base_headers = dict(headers) if headers else {}
            base_headers["Range"] = f"bytes={start_off}-{end_off}"
            payload = self._build_payload("GET", url, base_headers, b"")
            last_err = "unknown"
            for attempt in range(max_tries):
                if cancel_event.is_set():
                    return
                try:
                    raw = await self._relay_payload_h1(payload)
                    chunk_status, chunk_headers, chunk_body = split_raw_response(raw)
                    err = validate_range_response(
                        chunk_status, chunk_headers, chunk_body,
                        start_off, end_off, total_size,
                    )
                    if err is None:
                        async with file_lock:
                            await asyncio.to_thread(
                                spool_write, temp_file, start_off, chunk_body,
                            )
                        ready[index].set()
                        return
                    last_err = err
                except Exception as exc:
                    last_err = repr(exc)
                if cancel_event.is_set():
                    return
                log.warning("Range %d-%d retry %d/%d: %s",
                            start_off, end_off, attempt + 1, max_tries, last_err)
                # Back off before the next attempt; there is nothing to wait
                # for once the final attempt has failed.
                if attempt + 1 < max_tries:
                    await asyncio.sleep(0.3 * (attempt + 1))
            errors[index] = RuntimeError(
                f"chunk {start_off}-{end_off} failed after {max_tries} tries: {last_err}"
            )
            ready[index].set()

    try:
        writer.write(self._render_streaming_headers(resp_hdrs, total_size))
        writer.write(resp_body)
        await writer.drain()

        for index, (start_off, end_off) in enumerate(ranges):
            tasks.append(asyncio.create_task(fetch_range(index, start_off, end_off)))

        # Deliver chunks strictly in order, waiting for each one in turn.
        for index, (start_off, end_off) in enumerate(ranges):
            await ready[index].wait()
            if errors[index] is not None:
                raise errors[index]
            expected = end_off - start_off + 1
            async with file_lock:
                chunk = await asyncio.to_thread(
                    spool_read, temp_file, start_off, expected,
                )
            if len(chunk) != expected:
                raise RuntimeError(
                    f"spooled chunk {start_off}-{end_off} was truncated "
                    f"({len(chunk)}/{expected} B)"
                )
            writer.write(chunk)
            await writer.drain()
            delivered_chunks += 1
            delivered_bytes += len(chunk)
            await _write_progress(force=(index == len(ranges) - 1))

        elapsed = max(0.001, time.perf_counter() - t0)
        log.info(
            "Parallel streaming download complete: %s",
            progress_line(
                elapsed=elapsed,
                done=total_size,
                total=total_size,
                speed_bytes_per_sec=total_size / elapsed,
            ),
        )
        return True
    except (ConnectionError, BrokenPipeError, TimeoutError) as exc:
        log.info("Parallel download cancelled by client: %s", exc)
        cancel_event.set()
        return True
    except Exception as exc:
        self._mark_stream_download_failure(url, str(exc))
        log.error("Parallel streaming download failed (%s): %s", url[:60], exc)
        cancel_event.set()
        # Headers were already sent, so the only way to signal failure is to
        # drop the connection. Closing is best-effort.
        try:
            if not writer.is_closing():
                writer.close()
        except Exception:
            pass
        return True
    finally:
        cancel_event.set()
        for task in tasks:
            task.cancel()
        if tasks:
            await asyncio.gather(*tasks, return_exceptions=True)
        temp_file.close()
|
||
|
||
@staticmethod
|
||
def _rewrite_206_to_200(raw: bytes) -> bytes:
|
||
"""Rewrite a 206 Partial Content response to 200 OK.
|
||
|
||
Used when we probed with a synthetic Range header but the client
|
||
never asked for one. Handing a 206 back to the browser for a plain
|
||
GET breaks XHR/fetch on sites like x.com and Cloudflare challenges
|
||
(they see it as an aborted/partial response). We drop the
|
||
Content-Range header and set Content-Length to the body size.
|
||
"""
|
||
sep = b"\r\n\r\n"
|
||
if sep not in raw:
|
||
return raw
|
||
header_section, body = raw.split(sep, 1)
|
||
lines = header_section.decode(errors="replace").split("\r\n")
|
||
if not lines:
|
||
return raw
|
||
# Replace status line
|
||
first = lines[0]
|
||
if " 206" in first:
|
||
lines[0] = first.replace(" 206 Partial Content", " 200 OK")\
|
||
.replace(" 206", " 200 OK")
|
||
# Drop Content-Range and recalculate Content-Length
|
||
filtered = [lines[0]]
|
||
for ln in lines[1:]:
|
||
low = ln.lower()
|
||
if low.startswith("content-range:"):
|
||
continue
|
||
if low.startswith("content-length:"):
|
||
continue
|
||
filtered.append(ln)
|
||
filtered.append(f"Content-Length: {len(body)}")
|
||
return ("\r\n".join(filtered) + "\r\n\r\n").encode() + body
|
||
|
||
# Headers that must never be forwarded to the upstream server because
|
||
# they expose the user's real IP address or internal network topology.
|
||
_STRIP_HEADERS: frozenset = frozenset({
|
||
"accept-encoding", # Apps Script auto-decompresses gzip only
|
||
"x-forwarded-for", # would leak the client's real IP
|
||
"x-forwarded-host",
|
||
"x-forwarded-proto",
|
||
"x-forwarded-port",
|
||
"x-real-ip", # nginx / CDN header that carries real IP
|
||
"forwarded", # RFC 7239 — same problem
|
||
"via", # reveals intermediate proxy hops
|
||
"proxy-authorization", # never forward credentials to origin
|
||
"proxy-connection",
|
||
})
|
||
|
||
@staticmethod
|
||
def _strip_sabr_quality_tracks(body: bytes) -> bytes:
|
||
"""Strip field-3 (quality-track selection) entries from a SABR
|
||
segment-fetch protobuf.
|
||
|
||
SABR videoplayback POSTs come in two distinct message types:
|
||
|
||
• Segment-fetch — contains field-2 (0x12) top-level entries that
|
||
carry byte-range requests for video/audio segments. Field-3
|
||
(0x1a) entries in these messages are quality-track selectors that
|
||
ask googlevideo to bundle multiple simultaneous quality tracks into
|
||
one response, easily exceeding Apps Script UrlFetchApp's ~10 MB
|
||
buffer → 502. We strip them to force a single-track response.
|
||
|
||
• Session-init — contains field-5 (0x2a) entries and NO field-2
|
||
entries. Field-3 entries in this message type carry essential
|
||
session metadata (language, viewer state, etc.). Stripping them
|
||
corrupts the init handshake → CDN returns 403.
|
||
|
||
We therefore only strip field-3 entries when at least one field-2
|
||
entry is found at the top level (segment-fetch body). For any other
|
||
body type the original bytes are returned unchanged.
|
||
|
||
Only top-level fields are inspected; nested messages are left intact.
|
||
If any unrecognised wire type is encountered the remainder of the
|
||
buffer is copied verbatim so a malformed body is never silently lost.
|
||
"""
|
||
# ── phase 1: single pass — collect all top-level fields ────
|
||
# We need to know whether field 2 exists before deciding to strip.
|
||
# Rather than walking the buffer twice, we accumulate (field_number,
|
||
# seg_start, seg_end) tuples in one pass, then decide what to keep.
|
||
segments: list[tuple[int, int, int]] = [] # (field_number, start, end)
|
||
has_field2 = False
|
||
has_field3 = False
|
||
i = 0
|
||
n = len(body)
|
||
tail_start = n # if we bail early, copy from here
|
||
|
||
while i < n:
|
||
seg_start = i
|
||
# decode varint tag
|
||
tag = 0
|
||
shift = 0
|
||
while i < n:
|
||
b = body[i]; i += 1
|
||
tag |= (b & 0x7F) << shift
|
||
shift += 7
|
||
if not (b & 0x80):
|
||
break
|
||
else:
|
||
tail_start = seg_start
|
||
break
|
||
|
||
field_number = tag >> 3
|
||
wire_type = tag & 0x07
|
||
|
||
# advance i past the field value
|
||
if wire_type == 0: # varint
|
||
while i < n and (body[i] & 0x80):
|
||
i += 1
|
||
if i < n:
|
||
i += 1
|
||
elif wire_type == 1: # 64-bit fixed
|
||
i = min(i + 8, n)
|
||
elif wire_type == 2: # length-delimited
|
||
val_len = 0
|
||
shift = 0
|
||
while i < n:
|
||
b = body[i]; i += 1
|
||
val_len |= (b & 0x7F) << shift
|
||
shift += 7
|
||
if not (b & 0x80):
|
||
break
|
||
i = min(i + val_len, n)
|
||
elif wire_type == 5: # 32-bit fixed
|
||
i = min(i + 4, n)
|
||
else:
|
||
# unknown wire type — bail, copy rest verbatim from seg_start
|
||
tail_start = seg_start
|
||
break
|
||
|
||
if field_number == 2:
|
||
has_field2 = True
|
||
elif field_number == 3:
|
||
has_field3 = True
|
||
segments.append((field_number, seg_start, i))
|
||
|
||
# ── phase 2: decide ────────────────────────────────────────
|
||
# Only strip when this is a segment-fetch body (has field 2).
|
||
# Initialization bodies lack field 2 — field 3 is essential there.
|
||
if not has_field2 or not has_field3:
|
||
return body # nothing to do — return original unchanged
|
||
|
||
out = bytearray()
|
||
for field_number, seg_start, seg_end in segments:
|
||
if field_number != 3:
|
||
out.extend(body[seg_start:seg_end])
|
||
# append any tail bytes that were copied verbatim on early bail
|
||
out.extend(body[tail_start:])
|
||
return bytes(out)
|
||
|
||
def _build_payload(self, method, url, headers, body):
|
||
"""Build the JSON relay payload dict."""
|
||
# Apps Script's UrlFetchApp.fetch() does not accept HEAD or OPTIONS
|
||
# methods — passing either throws and the relay returns 502. Map
|
||
# them to GET on the wire (the upstream still gets the same response
|
||
# body, and HEAD-aware HTTP clients ignore the body anyway since
|
||
# they look at Content-Length / framing). This is the defensive
|
||
# mirror of the same normalisation done in Code.gs.
|
||
upper_method = method.upper() if method else "GET"
|
||
wire_method = "GET" if upper_method in ("HEAD", "OPTIONS") else method
|
||
payload = {
|
||
"m": wire_method,
|
||
"u": url,
|
||
# Let the browser/app see origin redirects and cookies directly.
|
||
"r": False,
|
||
}
|
||
if headers:
|
||
# Strip headers that would leak the user's real IP or expose
|
||
# internal proxy metadata to the upstream destination server.
|
||
# IMPORTANT: always use the filtered dict — never fall back to
|
||
# the original headers even when filt is empty, because that would
|
||
# re-send the very IP-leak headers we just stripped.
|
||
filt = {k: v for k, v in headers.items()
|
||
if k.lower() not in self._STRIP_HEADERS}
|
||
if filt:
|
||
payload["h"] = filt
|
||
if body:
|
||
payload["b"] = base64.b64encode(body).decode()
|
||
ct = headers.get("Content-Type") or headers.get("content-type")
|
||
if ct:
|
||
payload["ct"] = ct
|
||
return payload
|
||
|
||
@classmethod
|
||
def _is_static_asset_url(cls, url: str) -> bool:
|
||
path = urlparse(url).path.lower()
|
||
# Also match versioned paths like /script.js/v3a4b… or /font.woff2/hash
|
||
return any(path.endswith(ext) or f"{ext}/" in path for ext in cls._STATIC_EXTS)
|
||
|
||
@staticmethod
|
||
def _header_value(headers: dict | None, name: str) -> str:
|
||
if not headers:
|
||
return ""
|
||
for key, value in headers.items():
|
||
if key.lower() == name:
|
||
return str(value)
|
||
return ""
|
||
|
||
@classmethod
|
||
def _is_stateful_request(cls, method: str, url: str,
|
||
headers: dict | None, body: bytes) -> bool:
|
||
method = method.upper()
|
||
if method not in {"GET", "HEAD"} or body:
|
||
return True
|
||
|
||
# Static assets are safe to batch in parallel as independent requests.
|
||
is_static = cls._is_static_asset_url(url)
|
||
|
||
if headers and not is_static:
|
||
# Static assets (.css, .js, .woff2, .png, …) are served the same
|
||
# regardless of cookies — browsers always attach cookies but the
|
||
# server doesn't vary static responses on them. Only apply
|
||
# header-based stateful checks to non-static URLs.
|
||
for name in ("cookie", "authorization", "proxy-authorization"):
|
||
if cls._header_value(headers, name):
|
||
return True
|
||
|
||
accept = cls._header_value(headers, "accept").lower()
|
||
if "text/html" in accept:
|
||
return True
|
||
|
||
fetch_mode = cls._header_value(headers, "sec-fetch-mode").lower()
|
||
if fetch_mode == "navigate":
|
||
return True
|
||
|
||
fetch_dest = cls._header_value(headers, "sec-fetch-dest").lower()
|
||
if fetch_dest in {"document", "iframe", "frame"}:
|
||
return True
|
||
|
||
# Non-static JSON/API calls are treated as stateful by default.
|
||
if "application/json" in accept:
|
||
return True
|
||
|
||
return not is_static
|
||
|
||
# ── Batch collector ───────────────────────────────────────────
|
||
|
||
async def _batch_submit(self, payload: dict) -> bytes:
|
||
"""Submit a request to the batch collector. Returns raw HTTP response."""
|
||
# If batching is disabled, retry enabling it after a cooldown.
|
||
if not self._batch_enabled:
|
||
if (
|
||
not self._batch_permanent_disable
|
||
and self._batch_disabled_at > 0
|
||
and (time.time() - self._batch_disabled_at) >= self._batch_cooldown
|
||
):
|
||
self._batch_enabled = True
|
||
log.info(
|
||
"Batch mode re-enabled after %ds cooldown",
|
||
self._batch_cooldown,
|
||
)
|
||
else:
|
||
return await self._relay_with_retry(payload)
|
||
|
||
future = asyncio.get_running_loop().create_future()
|
||
|
||
async with self._batch_lock:
|
||
self._batch_pending.append((payload, future))
|
||
|
||
if len(self._batch_pending) >= self._batch_max:
|
||
# Batch is full — flush now
|
||
batch = self._batch_pending[:]
|
||
self._batch_pending.clear()
|
||
if self._batch_task and not self._batch_task.done():
|
||
self._batch_task.cancel()
|
||
self._batch_task = None
|
||
self._spawn(self._batch_send(batch))
|
||
elif self._batch_task is None or self._batch_task.done():
|
||
# First request in a new batch window — start timer
|
||
self._batch_task = self._spawn(self._batch_timer())
|
||
|
||
return await future
|
||
|
||
async def _batch_timer(self):
|
||
"""Two-tier batch window: 15ms micro + 120ms macro.
|
||
|
||
Single requests (link clicks) get only 15ms delay.
|
||
Burst traffic (page sub-resources, range chunks) gets a 120ms
|
||
window to accumulate, enabling much larger batches.
|
||
"""
|
||
# Tier 1: micro-window — detect if burst or single
|
||
await asyncio.sleep(self._batch_window_micro)
|
||
async with self._batch_lock:
|
||
if len(self._batch_pending) <= 1:
|
||
# Single request — send immediately (only 5ms delay)
|
||
if self._batch_pending:
|
||
batch = self._batch_pending[:]
|
||
self._batch_pending.clear()
|
||
self._batch_task = None
|
||
self._spawn(self._batch_send(batch))
|
||
return
|
||
|
||
# Tier 2: burst detected — wait more to accumulate
|
||
await asyncio.sleep(self._batch_window_macro - self._batch_window_micro)
|
||
async with self._batch_lock:
|
||
if self._batch_pending:
|
||
batch = self._batch_pending[:]
|
||
self._batch_pending.clear()
|
||
self._batch_task = None
|
||
self._spawn(self._batch_send(batch))
|
||
|
||
@staticmethod
|
||
def _split_list(lst: list, n: int) -> list[list]:
|
||
"""Split lst into n roughly-equal contiguous chunks (no empty chunks)."""
|
||
n = min(n, len(lst))
|
||
k, rem = divmod(len(lst), n)
|
||
chunks, start = [], 0
|
||
for i in range(n):
|
||
size = k + (1 if i < rem else 0)
|
||
chunks.append(lst[start:start + size])
|
||
start += size
|
||
return chunks
|
||
|
||
@staticmethod
|
||
def _url_ext(url: str) -> str:
|
||
"""Extract the lowercase file extension from a URL path (no query)."""
|
||
try:
|
||
path = urlparse(url).path
|
||
if "." in path:
|
||
return path.rsplit(".", 1)[-1].lower()
|
||
except Exception:
|
||
pass
|
||
return ""
|
||
|
||
def _make_sub_batches(self, batch: list, n_connections: int) -> list[list]:
|
||
"""Build sub-batches that isolate heavy (binary) from light requests.
|
||
|
||
A 2 kB CSS file and a 400 kB font batched together mean the CSS
|
||
future doesn't resolve until the font finishes downloading at 40 KB/s
|
||
(~15s). By separating heavy files onto their own H2 connection the
|
||
light files resolve in <1s and the browser can continue rendering
|
||
while the large binaries transfer in parallel.
|
||
"""
|
||
if n_connections <= 1:
|
||
return [batch]
|
||
|
||
heavy, light = [], []
|
||
for item in batch:
|
||
url = item[0].get("u", "")
|
||
ext = self._url_ext(url)
|
||
(heavy if ext in self._HEAVY_EXTENSIONS else light).append(item)
|
||
|
||
if not heavy:
|
||
# All light items (CSS, JS, JSON…) — keep as a single batch.
|
||
# Each sub-batch is one Apps Script execution; splitting N small
|
||
# files into N executions wastes N× quota with negligible DPI
|
||
# benefit (small payloads clear the token bucket quickly anyway).
|
||
return [batch]
|
||
if not light:
|
||
# All heavy items — split across connections so each large file
|
||
# gets its own DPI token bucket (parallel throughput).
|
||
return self._split_list(batch, min(n_connections, len(batch)))
|
||
|
||
# Reserve n_connections-1 slots for heavy items (each gets its own
|
||
# throughput budget); give the remaining slot(s) to light items.
|
||
n_heavy_slots = min(n_connections - 1, len(heavy))
|
||
sub_batches = self._split_list(heavy, n_heavy_slots)
|
||
n_light_slots = n_connections - len(sub_batches)
|
||
if n_light_slots > 1 and len(light) >= n_light_slots:
|
||
sub_batches += self._split_list(light, n_light_slots)
|
||
else:
|
||
sub_batches.append(light)
|
||
return [s for s in sub_batches if s]
|
||
|
||
async def _batch_send(self, batch: list):
    """Send a batch of requests, split across H2 connections for parallel throughput.

    Iran's DPI shapes per-TCP-connection.  A 600 kB response over one
    connection at 40 KB/s takes ~15s.  Splitting the same batch across 3
    connections means each carries ~200 kB → ~5s, all in parallel → 3×
    faster wall-clock.  Each sub-batch is an independent Apps Script
    fetchAll call on a separate H2 transport.

    Args:
        batch: List of ``(payload, future)`` pairs.  Every future is
            resolved with a raw response; failures are routed through
            ``_relay_fallback`` so no future is left pending.
    """
    if len(batch) == 1:
        payload, future = batch[0]
        try:
            result = await self._relay_with_retry(payload)
            if not future.done():
                future.set_result(result)
        except Exception as e:
            if not future.done():
                future.set_result(error_response(502, str(e)))
        return

    # Determine how many live H2 connections to split across.
    n_live = (
        sum(1 for t in self._h2_pool if t.is_connected)
        if self._h2_pool else 0
    )
    n_splits = min(n_live, len(batch)) if (n_live > 1 and self._sub_batch_enabled) else 1

    if n_splits > 1:
        # Build size-aware sub-batches: heavy files (fonts, images) get
        # their own H2 connection so light files don't wait for them.
        chunks = self._make_sub_batches(batch, n_splits)

        heavy_count = sum(
            1 for p, _ in batch
            if self._url_ext(p.get("u", "")) in self._HEAVY_EXTENSIONS
        )
        log.info(
            "Batch relay: %d requests (%d heavy+%d light) → %d sub-batches (%s)",
            len(batch), heavy_count, len(batch) - heavy_count,
            len(chunks), "+".join(str(len(c)) for c in chunks),
        )

        # Wrap each sub-batch relay with timing so slow connections are
        # logged and we can correlate them with DPI shaping events.
        async def _timed_sub_batch(items: list):
            t0 = time.perf_counter()
            result = await self._relay_batch([p for p, _ in items])
            return result, time.perf_counter() - t0

        chunk_results = await asyncio.gather(
            *[_timed_sub_batch(c) for c in chunks],
            return_exceptions=True,
        )

        max_dt = 0.0
        for chunk, result in zip(chunks, chunk_results):
            # gather() may hand back CancelledError, which derives from
            # BaseException rather than Exception — check the base class so
            # it is treated as a failure instead of being unpacked below.
            if isinstance(result, BaseException):
                log.warning(
                    "Sub-batch failed (%s: %s), retrying individually",
                    type(result).__name__, result,
                )
                for payload, future in chunk:
                    self._spawn(self._relay_fallback(payload, future))
                continue

            items_result, dt = result
            max_dt = max(max_dt, dt)
            if dt > 8.0:
                log.warning(
                    "Slow sub-batch: %.1fs for %d items — DPI shaping?",
                    dt, len(chunk),
                )
            for (_, future), raw in zip(chunk, items_result):
                if not future.done():
                    future.set_result(raw)
            # A short result list would leave the trailing futures pending
            # forever; relay those requests individually instead.
            for payload, future in chunk[len(items_result):]:
                if not future.done():
                    self._spawn(self._relay_fallback(payload, future))
        if max_dt > 0:
            log.debug("Batch wall-clock: %.1fs", max_dt)
        return

    # Single-batch path: H2 unavailable or only one connection live.
    log.info("Batch relay: %d requests", len(batch))
    try:
        results = await self._relay_batch([p for p, _ in batch])
        for (_, future), result in zip(batch, results):
            if not future.done():
                future.set_result(result)
        # Same safety net as above for a short result list.
        unanswered = [
            (p, f) for p, f in batch[len(results):] if not f.done()
        ]
        if unanswered:
            await asyncio.gather(
                *(self._relay_fallback(p, f) for p, f in unanswered)
            )
    except Exception as e:
        # Only globally disable batch mode for genuine failures (parse
        # errors, protocol errors).  A bare TimeoutError or transient
        # connection drop is recoverable on the very next batch — keeping
        # batch mode disabled for 60s while traffic floods (e.g. a Vercel
        # marketing page with 200+ chunks) collapses every request into
        # its own Apps Script execution and explodes quota usage.
        transient = isinstance(e, (asyncio.TimeoutError, ConnectionError,
                                   TimeoutError, OSError))
        if transient:
            log.warning(
                "Batch relay transient error (%s: %s) — falling back "
                "individually but keeping batch mode enabled",
                type(e).__name__, e or "(no details)",
            )
        else:
            log.warning(
                "Batch relay failed, disabling batch mode for %ds cooldown. "
                "Error: %s: %s",
                self._batch_cooldown, type(e).__name__, e or "(no details)",
            )
            self._batch_enabled = False
            self._batch_disabled_at = time.time()
        # Fallback: send individually, skipping requests that were already
        # answered so they are not relayed a second time.
        await asyncio.gather(*(
            self._relay_fallback(payload, future)
            for payload, future in batch
            if not future.done()
        ))
|
||
|
||
async def _relay_fallback(self, payload, future):
|
||
"""Fallback: relay a single request from a failed batch."""
|
||
try:
|
||
result = await self._relay_with_retry(payload)
|
||
if not future.done():
|
||
future.set_result(result)
|
||
except Exception as e:
|
||
if not future.done():
|
||
future.set_result(error_response(502, str(e)))
|
||
|
||
# ── Core relay with retry ─────────────────────────────────────
|
||
|
||
async def _relay_with_retry(self, payload: dict) -> bytes:
|
||
"""Single relay with one retry on failure. Uses H2 if available."""
|
||
attempts = self._retry_attempts_for_payload(payload)
|
||
# Fan-out: race N Apps Script instances when enabled and H2 is up.
|
||
# Cuts tail latency when one container is slow/cold. Only kicks in
|
||
# if multiple script IDs are configured and the H2 transport is live.
|
||
if (attempts > 1
|
||
and self._parallel_relay > 1
|
||
and len(self._script_ids) > 1
|
||
and self._h2_available()):
|
||
try:
|
||
result = await asyncio.wait_for(
|
||
self._relay_fanout(payload), timeout=self._relay_timeout,
|
||
)
|
||
self._record_h2_success()
|
||
return result
|
||
except Exception as e:
|
||
if self._is_h2_transport_error(e):
|
||
self._record_h2_failure(e)
|
||
self._spawn(self._reconnect_pool_members())
|
||
log.debug("Fan-out relay failed (%s), falling back", e)
|
||
# fall through to single-path logic below
|
||
|
||
# Try HTTP/2 first — much faster (multiplexed, no pool checkout)
|
||
if self._h2_available():
|
||
for attempt in range(attempts):
|
||
try:
|
||
result = await asyncio.wait_for(
|
||
self._relay_single_h2(payload), timeout=self._relay_timeout
|
||
)
|
||
self._record_h2_success()
|
||
return result
|
||
except Exception as e:
|
||
is_transport = self._is_h2_transport_error(e)
|
||
if is_transport:
|
||
self._record_h2_failure(e)
|
||
# Spawn background reconnect for any newly-dead transports
|
||
# so future requests find healthy connections.
|
||
self._spawn(self._reconnect_pool_members())
|
||
if attempt < attempts - 1 and self._h2_available():
|
||
log.debug("H2 relay attempt %d failed (%s: %s), retrying",
|
||
attempt + 1, type(e).__name__, e)
|
||
else:
|
||
log.debug(
|
||
"H2 relay failed (%s: %s), falling back to H1",
|
||
type(e).__name__, e,
|
||
)
|
||
break
|
||
|
||
# HTTP/1.1 fallback (pool-based)
|
||
async with self._semaphore:
|
||
for attempt in range(attempts):
|
||
try:
|
||
return await asyncio.wait_for(
|
||
self._relay_single(payload), timeout=self._relay_timeout
|
||
)
|
||
except Exception as e:
|
||
if attempt < attempts - 1:
|
||
log.debug("Relay attempt %d failed (%s: %s), retrying",
|
||
attempt + 1,
|
||
type(e).__name__, e)
|
||
await self._flush_pool()
|
||
else:
|
||
raise
|
||
|
||
async def _relay_fanout(self, payload: dict) -> bytes:
|
||
"""Fire the same relay against N distinct script IDs in parallel.
|
||
|
||
Returns the first successful response; cancels the rest as soon as
|
||
one finishes. Any script that raises or loses the race AND later
|
||
fails individually is blacklisted for SCRIPT_BLACKLIST_TTL.
|
||
"""
|
||
host_key = self._host_key(payload.get("u"))
|
||
sids = self._pick_fanout_sids(host_key)
|
||
if len(sids) <= 1:
|
||
# Nothing to race against (e.g. all others blacklisted)
|
||
return await self._relay_single_h2_with_sid(payload, sids[0])
|
||
|
||
tasks = {
|
||
asyncio.create_task(
|
||
self._relay_single_h2_with_sid(payload, sid)
|
||
): sid
|
||
for sid in sids
|
||
}
|
||
winner_result: bytes | None = None
|
||
winner_exc: BaseException | None = None
|
||
pending = set(tasks.keys())
|
||
try:
|
||
while pending:
|
||
done, pending = await asyncio.wait(
|
||
pending, return_when=asyncio.FIRST_COMPLETED,
|
||
)
|
||
for t in done:
|
||
sid = tasks[t]
|
||
exc = t.exception()
|
||
if exc is None:
|
||
winner_result = t.result()
|
||
return winner_result
|
||
# This racer failed — blacklist and keep waiting for others
|
||
self._blacklist_sid(sid, reason=type(exc).__name__)
|
||
winner_exc = exc
|
||
# All racers failed
|
||
if winner_exc is not None:
|
||
raise winner_exc
|
||
raise RuntimeError("fan-out relay: all racers failed")
|
||
finally:
|
||
for t in pending:
|
||
t.cancel()
|
||
# Drain cancelled tasks so they don't log warnings
|
||
if pending:
|
||
await asyncio.gather(*pending, return_exceptions=True)
|
||
|
||
async def _relay_single_h2(self, payload: dict) -> bytes:
    """Relay one payload over HTTP/2 and return the parsed raw response.

    The auth key is added to a copy of the payload, the script ID is
    derived from the target URL's host key, and the POST goes out on the
    transport returned by ``_pick_h2()`` (falling back to ``self._h2``).
    The elapsed time is logged at DEBUG level.
    """
    request_body = json.dumps({**payload, "k": self.auth_key}).encode()

    sid = self._script_id_for_key(self._host_key(payload.get("u")))
    exec_path = self._exec_path_for_sid(sid)
    self._record_execution(sid)

    started = time.perf_counter()
    transport = self._pick_h2() or self._h2
    _status, _resp_headers, body = await transport.request(
        method="POST", path=exec_path, host=self.http_host,
        headers=self._apps_script_headers(),
        body=request_body,
        timeout=self._relay_timeout,
    )
    if log.isEnabledFor(logging.DEBUG):
        elapsed_ms = (time.perf_counter() - started) * 1000.0
        log.debug(
            "H2 relay %s [%s] %.0fms (%d bytes)",
            payload.get("m", "?"),
            (payload.get("u") or "")[:60],
            elapsed_ms,
            len(body),
        )

    return parse_relay_response(body, self._max_response_body_bytes)
|
||
|
||
async def _relay_single_h2_with_sid(self, payload: dict,
                                    sid: str) -> bytes:
    """Relay one payload over HTTP/2 through a specific deployment.

    Used by `_relay_fanout` to race multiple script IDs in parallel.
    Works like `_relay_single_h2` except that the script ID is supplied
    by the caller instead of being derived from the target URL.
    """
    request_body = json.dumps({**payload, "k": self.auth_key}).encode()

    exec_path = self._exec_path_for_sid(sid)
    self._record_execution(sid)

    transport = self._pick_h2() or self._h2
    _status, _resp_headers, body = await transport.request(
        method="POST", path=exec_path, host=self.http_host,
        headers=self._apps_script_headers(),
        body=request_body,
        timeout=self._relay_timeout,
    )

    return parse_relay_response(body, self._max_response_body_bytes)
|
||
|
||
async def _follow_redirects(
|
||
self,
|
||
reader: asyncio.StreamReader,
|
||
writer: asyncio.StreamWriter,
|
||
status: int,
|
||
resp_headers: dict,
|
||
resp_body: bytes,
|
||
original_body: bytes,
|
||
) -> tuple[int, dict, bytes]:
|
||
"""Follow up to 5 HTTP redirects on an existing H1 connection.
|
||
|
||
307/308 preserve the request method and body; all others become
|
||
GET with an empty body (RFC 7231 §6.4).
|
||
"""
|
||
for _ in range(5):
|
||
if status not in (301, 302, 303, 307, 308):
|
||
break
|
||
location = resp_headers.get("location")
|
||
if not location:
|
||
break
|
||
parsed = urlparse(location)
|
||
rpath = parsed.path + ("?" + parsed.query if parsed.query else "")
|
||
if status in (307, 308):
|
||
redirect_method = "POST"
|
||
redirect_body = original_body
|
||
else:
|
||
redirect_method = "GET"
|
||
redirect_body = b""
|
||
request_lines = [
|
||
f"{redirect_method} {rpath} HTTP/1.1",
|
||
f"Host: {parsed.netloc}",
|
||
"Accept: application/json,text/plain,*/*",
|
||
"Accept-Language: en-US,en;q=0.9",
|
||
"Accept-Encoding: gzip",
|
||
"Connection: keep-alive",
|
||
]
|
||
if redirect_body:
|
||
request_lines.append(f"Content-Length: {len(redirect_body)}")
|
||
request = "\r\n".join(request_lines) + "\r\n\r\n"
|
||
writer.write(request.encode() + redirect_body)
|
||
await writer.drain()
|
||
status, resp_headers, resp_body = await read_http_response(
|
||
reader, max_bytes=self._max_response_body_bytes
|
||
)
|
||
return status, resp_headers, resp_body
|
||
|
||
async def _relay_single(self, payload: dict) -> bytes:
    """Execute a single relay POST → redirect → parse.

    A copy of ``payload`` with the auth key added under ``"k"`` is sent
    as JSON over a connection obtained from ``self._acquire()``. The
    deployment is chosen from the payload's ``"u"`` entry.

    Args:
        payload: Relay request dict; it is not modified.

    Returns:
        The result of ``parse_relay_response`` on the final response
        body (after any redirects have been followed).

    Raises:
        Any exception raised while sending or reading is re-raised after
        the connection has been closed.
    """
    # Add auth key
    full_payload = dict(payload)
    full_payload["k"] = self.auth_key
    json_body = json.dumps(full_payload).encode()

    sid = self._script_id_for_key(self._host_key(payload.get("u")))
    path = self._exec_path_for_sid(sid)
    reader, writer, created = await self._acquire()

    try:
        request = (
            f"POST {path} HTTP/1.1\r\n"
            f"Host: {self.http_host}\r\n"
            f"Content-Type: application/json\r\n"
            f"Accept: application/json,text/plain,*/*\r\n"
            f"Accept-Language: en-US,en;q=0.9\r\n"
            f"Content-Length: {len(json_body)}\r\n"
            f"Accept-Encoding: gzip\r\n"
            f"Connection: keep-alive\r\n"
            f"\r\n"
        )
        writer.write(request.encode() + json_body)
        await writer.drain()
        self._record_execution(sid)

        status, resp_headers, resp_body = await read_http_response(
            reader, max_bytes=self._max_response_body_bytes
        )
        status, resp_headers, resp_body = await self._follow_redirects(
            reader, writer, status, resp_headers, resp_body, json_body
        )

        await self._release(reader, writer, created)

    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError is
        # not an Exception subclass, and a cancelled exchange leaves the
        # connection mid-request, so it must not stay open unreleased.
        try:
            writer.close()
        except Exception:
            # Best-effort close; the original error is what matters.
            pass
        raise

    # Parse only after the connection has been released, outside the
    # try block: a parse failure must not close a connection that has
    # already been handed back via _release().
    return parse_relay_response(resp_body, self._max_response_body_bytes)
|
||
|
||
async def _relay_batch(self, payloads: list[dict]) -> list[bytes]:
|
||
"""Send multiple requests in one POST using Apps Script fetchAll."""
|
||
batch_payload = {
|
||
"k": self.auth_key,
|
||
"q": payloads,
|
||
}
|
||
json_body = json.dumps(batch_payload).encode()
|
||
sid = self._script_id_for_key(
|
||
self._host_key(payloads[0].get("u") if payloads else None)
|
||
)
|
||
path = self._exec_path_for_sid(sid)
|
||
|
||
# Try HTTP/2 first. Use the configured relay_timeout: batches can
|
||
# carry the combined response of N requests so they need at least as
|
||
# much time as a single relay. A hardcoded 30s timed out legitimate
|
||
# large-asset bursts and forced batch mode into a 60s cooldown,
|
||
# collapsing every subsequent request into its own Apps Script
|
||
# execution (huge quota burn).
|
||
if self._h2_available():
|
||
batch_timeout = max(self._relay_timeout, 30.0)
|
||
try:
|
||
self._record_execution(sid)
|
||
t0 = time.perf_counter()
|
||
status, headers, body = await asyncio.wait_for(
|
||
(self._pick_h2() or self._h2).request(
|
||
method="POST", path=path, host=self.http_host,
|
||
headers=self._apps_script_headers(),
|
||
body=json_body,
|
||
timeout=batch_timeout,
|
||
),
|
||
timeout=batch_timeout,
|
||
)
|
||
if log.isEnabledFor(logging.DEBUG):
|
||
log.debug(
|
||
"H2 batch %d items: %.0fms (%d bytes)",
|
||
len(payloads),
|
||
(time.perf_counter() - t0) * 1000.0,
|
||
len(body),
|
||
)
|
||
self._record_h2_success()
|
||
return self._parse_batch_body(body, payloads)
|
||
except Exception as e:
|
||
if self._is_h2_transport_error(e):
|
||
self._record_h2_failure(e)
|
||
self._spawn(self._reconnect_pool_members())
|
||
log.debug("H2 batch failed (%s), falling back to H1", e)
|
||
|
||
# HTTP/1.1 fallback
|
||
async with self._semaphore:
|
||
reader, writer, created = await self._acquire()
|
||
try:
|
||
request = (
|
||
f"POST {path} HTTP/1.1\r\n"
|
||
f"Host: {self.http_host}\r\n"
|
||
f"Content-Type: application/json\r\n"
|
||
f"Content-Length: {len(json_body)}\r\n"
|
||
f"Accept-Encoding: gzip\r\n"
|
||
f"Connection: keep-alive\r\n"
|
||
f"\r\n"
|
||
)
|
||
writer.write(request.encode() + json_body)
|
||
await writer.drain()
|
||
self._record_execution(sid)
|
||
|
||
status, resp_headers, resp_body = await read_http_response(
|
||
reader, max_bytes=self._max_response_body_bytes
|
||
)
|
||
status, resp_headers, resp_body = await self._follow_redirects(
|
||
reader, writer, status, resp_headers, resp_body, json_body
|
||
)
|
||
|
||
await self._release(reader, writer, created)
|
||
|
||
except Exception:
|
||
try:
|
||
writer.close()
|
||
except Exception:
|
||
pass
|
||
raise
|
||
|
||
return self._parse_batch_body(resp_body, payloads)
|
||
|
||
def _parse_batch_body(self, resp_body: bytes,
                      payloads: list[dict]) -> list[bytes]:
    """Parse a batch response body into individual results.

    Args:
        resp_body: Raw body of the batch response.
        payloads: The request payloads that were sent; only their count
            is used, to check the response has one item per request.

    Returns:
        ``parse_relay_json`` applied to each item of the response's
        ``"q"`` list, in order.

    Raises:
        RuntimeError: If the body does not load to a non-empty JSON
            object, if it carries an ``"e"`` error entry, if ``"q"`` is
            not a list, or if the item count differs from
            ``len(payloads)``.
    """
    text = resp_body.decode(errors="replace").strip()
    # Apps Script can wrap JSON inside an HTML shell; reuse the same
    # robust loader used by single-response parsing.
    data = load_relay_json(text)
    # Anything other than a non-empty JSON object is malformed; checking
    # the type here keeps the failure a RuntimeError instead of an
    # AttributeError from data.get() below.
    if not data or not isinstance(data, dict):
        raise RuntimeError(f"Bad batch response: {text[:200]}")

    if "e" in data:
        raise RuntimeError(f"Batch error: {data['e']}")

    items = data.get("q", [])
    if not isinstance(items, list):
        # e.g. "q": null or an object — len()/iteration would misbehave.
        raise RuntimeError(f"Bad batch response: {text[:200]}")
    if len(items) != len(payloads):
        raise RuntimeError(
            f"Batch size mismatch: {len(items)} vs {len(payloads)}"
        )

    return [
        parse_relay_json(item, self._max_response_body_bytes)
        for item in items
    ]
|
||
|