Major refactor on codes (no feature is new, just code refactor)

Co-authored-by: Copilot <copilot@github.com>
This commit is contained in:
Abolfazl
2026-05-02 12:03:59 +03:30
parent bd98098499
commit c5beb51df0
25 changed files with 1950 additions and 1035 deletions
+9 -1
View File
@@ -92,6 +92,14 @@ and generates a strong random password for you. Follow the Apps Script deploymen
instructions in **Step 2** below before running the wizard so you have a instructions in **Step 2** below before running the wizard so you have a
Deployment ID ready. Deployment ID ready.
## Project Structure
- `src/core/` shared modules (config constants, logging, cert install, LAN, scanner)
- `src/proxy/` local proxy runtime (HTTP/SOCKS, MITM, proxy helpers)
- `src/relay/` Apps Script relay runtime (relay engine, parsing, H2, helpers)
- `apps_script/` deployable edge/runtime scripts
- `docs/exit-node/` exit-node deployment guides
After it's running, jump to **Step 5** (browser proxy) and **Step 6** (CA After it's running, jump to **Step 5** (browser proxy) and **Step 6** (CA
certificate). certificate).
@@ -188,7 +196,7 @@ You can deploy any one of these free exit-node templates:
3. Deno Deploy: [`apps_script/deno_deploy.ts`](apps_script/deno_deploy.ts) 3. Deno Deploy: [`apps_script/deno_deploy.ts`](apps_script/deno_deploy.ts)
Full step-by-step deployment guide (all providers): Full step-by-step deployment guide (all providers):
- [EXIT_NODE_DEPLOYMENT.md](EXIT_NODE_DEPLOYMENT.md) - [docs/exit-node/EXIT_NODE_DEPLOYMENT.md](docs/exit-node/EXIT_NODE_DEPLOYMENT.md)
Set the same PSK secret inside the exit-node code (`PSK` constant) and in `config.json`. Set the same PSK secret inside the exit-node code (`PSK` constant) and in `config.json`.
+12 -2
View File
@@ -61,6 +61,16 @@
--- ---
## ساختار پروژه
- `src/core/` ماژول‌های مشترک (ثابت‌ها، لاگ، نصب گواهی، LAN، اسکنر)
- `src/proxy/` هسته پراکسی محلی (HTTP/SOCKS، MITM، ابزارهای پراکسی)
- `src/relay/` هسته رله Apps Script (موتور رله، پارس پاسخ، H2، ابزارها)
- `apps_script/` اسکریپت‌های deploy روی سرویس‌های edge
- `docs/exit-node/` راهنماهای deployment نود خروجی
---
## راه‌اندازی مرحله‌به‌مرحله ## راه‌اندازی مرحله‌به‌مرحله
### مرحله 1: دریافت پروژه ### مرحله 1: دریافت پروژه
@@ -147,8 +157,8 @@ cp config.example.json config.json
3. Deno Deploy: [apps_script/deno_deploy.ts](apps_script/deno_deploy.ts) 3. Deno Deploy: [apps_script/deno_deploy.ts](apps_script/deno_deploy.ts)
راهنمای کامل مرحله‌به‌مرحله برای هر provider: راهنمای کامل مرحله‌به‌مرحله برای هر provider:
- [EXIT_NODE_DEPLOYMENT_FA.md](EXIT_NODE_DEPLOYMENT_FA.md) (فارسی) - [docs/exit-node/EXIT_NODE_DEPLOYMENT_FA.md](docs/exit-node/EXIT_NODE_DEPLOYMENT_FA.md) (فارسی)
- [EXIT_NODE_DEPLOYMENT.md](EXIT_NODE_DEPLOYMENT.md) (انگلیسی) - [docs/exit-node/EXIT_NODE_DEPLOYMENT.md](docs/exit-node/EXIT_NODE_DEPLOYMENT.md) (انگلیسی)
سپس همان secret را هم در کد نود خروجی (`PSK`) و هم در `config.json` یکسان بگذارید. سپس همان secret را هم در کد نود خروجی (`PSK`) و هم در `config.json` یکسان بگذارید.
+14 -18
View File
@@ -14,23 +14,19 @@ import logging
import os import os
import sys import sys
# Project modules live under ./src — put that folder on sys.path so the # Project modules live under ./src — add it to sys.path so package imports
# historical flat imports ("from proxy_server import …") keep working. # like "from proxy.proxy_server import ProxyServer" work from project root.
_SRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src") _SRC_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "src")
if _SRC_DIR not in sys.path: if _SRC_DIR not in sys.path:
sys.path.insert(0, _SRC_DIR) sys.path.insert(0, _SRC_DIR)
from cert_installer import install_ca, uninstall_ca, is_ca_trusted from core.cert_installer import install_ca, uninstall_ca, is_ca_trusted
from constants import __version__ from core.constants import __version__
from lan_utils import log_lan_access from core.lan_utils import log_lan_access
from google_ip_scanner import scan_sync from core.google_ip_scanner import scan_sync
from logging_utils import configure as configure_logging, print_banner from core.logging_utils import configure as configure_logging, print_banner
from mitm import CA_CERT_FILE from proxy.mitm import CA_CERT_FILE
from proxy_server import ProxyServer from proxy.proxy_server import ProxyServer
def setup_logging(level_name: str):
configure_logging(level_name)
_PLACEHOLDER_AUTH_KEYS = { _PLACEHOLDER_AUTH_KEYS = {
@@ -111,13 +107,13 @@ def main():
# Handle cert-only commands before loading config so they can run standalone. # Handle cert-only commands before loading config so they can run standalone.
if args.install_cert or args.uninstall_cert: if args.install_cert or args.uninstall_cert:
setup_logging("INFO") configure_logging("INFO")
_log = logging.getLogger("Main") _log = logging.getLogger("Main")
if args.install_cert: if args.install_cert:
_log.info("Installing CA certificate…") _log.info("Installing CA certificate…")
if not os.path.exists(CA_CERT_FILE): if not os.path.exists(CA_CERT_FILE):
from mitm import MITMCertManager from proxy.mitm import MITMCertManager
MITMCertManager() # side-effect: creates ca/ca.crt + ca/ca.key MITMCertManager() # side-effect: creates ca/ca.crt + ca/ca.key
ok = install_ca(CA_CERT_FILE) ok = install_ca(CA_CERT_FILE)
sys.exit(0 if ok else 1) sys.exit(0 if ok else 1)
@@ -219,14 +215,14 @@ def main():
# ── Google IP Scanner ────────────────────────────────────────────────── # ── Google IP Scanner ──────────────────────────────────────────────────
if args.scan: if args.scan:
setup_logging("INFO") configure_logging("INFO")
front_domain = config.get("front_domain", "www.google.com") front_domain = config.get("front_domain", "www.google.com")
_log = logging.getLogger("Main") _log = logging.getLogger("Main")
_log.info(f"Scanning Google IPs (fronting domain: {front_domain})") _log.info(f"Scanning Google IPs (fronting domain: {front_domain})")
ok = scan_sync(front_domain) ok = scan_sync(front_domain)
sys.exit(0 if ok else 1) sys.exit(0 if ok else 1)
setup_logging(config.get("log_level", "INFO")) configure_logging(config.get("log_level", "INFO"))
log = logging.getLogger("Main") log = logging.getLogger("Main")
print_banner(__version__) print_banner(__version__)
@@ -245,7 +241,7 @@ def main():
# Ensure CA file exists before checking / installing it. # Ensure CA file exists before checking / installing it.
# MITMCertManager generates ca/ca.crt on first instantiation. # MITMCertManager generates ca/ca.crt on first instantiation.
if not os.path.exists(CA_CERT_FILE): if not os.path.exists(CA_CERT_FILE):
from mitm import MITMCertManager from proxy.mitm import MITMCertManager
MITMCertManager() # side-effect: creates ca/ca.crt + ca/ca.key MITMCertManager() # side-effect: creates ca/ca.crt + ca/ca.key
# Auto-install MITM CA if not already trusted # Auto-install MITM CA if not already trusted
+694
View File
@@ -0,0 +1,694 @@
"""
Transport protocol & connection benchmark suite.
Tests run against Google's edge IP with SNI fronting. Four suites:
1. Protocol sequential — H1.1 / H2 / H3, one request at a time (apples-to-apples latency)
2. TLS session resumption — cold connect vs warm reconnect using cached session ticket
3. Concurrency — H2 multiplex (N streams on 1 conn) vs H1.1 parallel (N separate conns)
4. IP scan — probe all candidate Google IPs to find the fastest one on this network
Usage:
python scripts/benchmark_transport.py # reads config.json
python scripts/benchmark_transport.py --ip 216.239.38.120 --sni www.google.com
python scripts/benchmark_transport.py --suite protocol # only run suite 1
python scripts/benchmark_transport.py --suite resumption
python scripts/benchmark_transport.py --suite concurrency
python scripts/benchmark_transport.py --suite ipscan
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import socket
import ssl
import statistics
import sys
import time
from pathlib import Path
# ── Optional imports ──────────────────────────────────────────────────────
try:
import h2.connection
import h2.config
import h2.events
import h2.settings
H2_AVAILABLE = True
except ImportError:
H2_AVAILABLE = False
try:
import certifi
_CAFILE = certifi.where()
except ImportError:
_CAFILE = None
try:
import aioquic.asyncio as quic_asyncio
import aioquic.h3.connection as h3c
import aioquic.h3.events as h3e
import aioquic.quic.configuration as quic_cfg
import aioquic.quic.events as quic_events
H3_AVAILABLE = True
except ImportError:
H3_AVAILABLE = False
# ── TLS context helpers ───────────────────────────────────────────────────
def _make_tls_ctx(alpn: list[str]) -> ssl.SSLContext:
ctx = ssl.create_default_context()
if _CAFILE:
try:
ctx.load_verify_locations(cafile=_CAFILE)
except Exception:
pass
ctx.set_alpn_protocols(alpn)
return ctx
# ── HTTP/1.1 probe ────────────────────────────────────────────────────────
async def _probe_h1(host_ip: str, sni: str, path: str, timeout: float) -> float:
"""Return elapsed seconds for one H1.1 GET. Raises on error."""
ctx = _make_tls_ctx(["http/1.1"])
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
raw.setblocking(False)
t0 = time.perf_counter()
loop = asyncio.get_running_loop()
await asyncio.wait_for(loop.sock_connect(raw, (host_ip, 443)), timeout=timeout)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ssl=ctx, server_hostname=sni, sock=raw),
timeout=timeout,
)
req = (
f"GET {path} HTTP/1.1\r\n"
f"Host: {sni}\r\n"
"Accept: */*\r\n"
"Connection: close\r\n"
"\r\n"
).encode()
writer.write(req)
await asyncio.wait_for(writer.drain(), timeout=timeout)
resp = b""
while True:
chunk = await asyncio.wait_for(reader.read(4096), timeout=timeout)
if not chunk:
break
resp += chunk
if b"\r\n\r\n" in resp:
break
writer.close()
elapsed = time.perf_counter() - t0
if not resp.startswith(b"HTTP/"):
raise RuntimeError(f"Unexpected response: {resp[:60]!r}")
return elapsed
# ── HTTP/2 probe ──────────────────────────────────────────────────────────
async def _probe_h2_fresh(host_ip: str, sni: str, path: str, timeout: float) -> float:
"""One H2 GET on a NEW connection each time (apples-to-apples vs H1)."""
if not H2_AVAILABLE:
raise RuntimeError("h2 not installed")
ctx = _make_tls_ctx(["h2", "http/1.1"])
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
raw.setblocking(False)
t0 = time.perf_counter()
loop = asyncio.get_running_loop()
await asyncio.wait_for(loop.sock_connect(raw, (host_ip, 443)), timeout=timeout)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ssl=ctx, server_hostname=sni, sock=raw),
timeout=timeout,
)
ssl_obj = writer.get_extra_info("ssl_object")
negotiated = ssl_obj.selected_alpn_protocol() if ssl_obj else None
if negotiated != "h2":
writer.close()
raise RuntimeError(f"H2 ALPN failed (got {negotiated!r})")
cfg = h2.config.H2Configuration(client_side=True, header_encoding="utf-8")
conn = h2.connection.H2Connection(cfg)
conn.initiate_connection()
writer.write(conn.data_to_send(65535))
await writer.drain()
stream_id = conn.get_next_available_stream_id()
conn.send_headers(stream_id, [
(":method", "GET"),
(":path", path),
(":scheme", "https"),
(":authority", sni),
("accept", "*/*"),
], end_stream=True)
writer.write(conn.data_to_send(65535))
await asyncio.wait_for(writer.drain(), timeout=timeout)
headers_done = False
while not headers_done:
raw_data = await asyncio.wait_for(reader.read(65535), timeout=timeout)
if not raw_data:
break
events = conn.receive_data(raw_data)
writer.write(conn.data_to_send(65535))
await writer.drain()
for ev in events:
if isinstance(ev, (h2.events.ResponseReceived, h2.events.StreamEnded,
h2.events.DataReceived)):
if isinstance(ev, h2.events.ResponseReceived) and ev.stream_id == stream_id:
headers_done = True
writer.close()
return time.perf_counter() - t0
# ── HTTP/3 (QUIC) probe ───────────────────────────────────────────────────
class _H3ProbeProtocol(quic_asyncio.QuicConnectionProtocol):
"""Minimal aioquic protocol that sends one H3 GET and captures the result."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._h3: h3c.H3Connection | None = None
self._done: asyncio.Future[float] = asyncio.get_event_loop().create_future()
self._t0: float = time.perf_counter()
self._stream_id: int | None = None
def quic_event_received(self, event):
if isinstance(event, quic_events.HandshakeCompleted):
self._h3 = h3c.H3Connection(self._quic, enable_webtransport=False)
if self._h3 is None:
return
for h3ev in self._h3.handle_event(event):
if isinstance(h3ev, h3e.HeadersReceived):
if not self._done.done():
self._done.set_result(time.perf_counter() - self._t0)
elif isinstance(h3ev, h3e.DataReceived):
pass # don't need body
def send_request(self, sni: str, path: str):
self._stream_id = self._quic.get_next_available_stream_id()
self._h3.send_headers(
stream_id=self._stream_id,
headers=[
(b":method", b"GET"),
(b":path", path.encode()),
(b":scheme", b"https"),
(b":authority", sni.encode()),
(b"accept", b"*/*"),
],
end_stream=True,
)
self.transmit()
async def _h3_inner(host_ip: str, sni: str, path: str, timeout: float) -> float:
cfg = quic_cfg.QuicConfiguration(
is_client=True,
server_name=sni,
alpn_protocols=h3c.H3_ALPN,
verify_mode=ssl.CERT_REQUIRED,
)
if _CAFILE:
try:
cfg.load_verify_locations(_CAFILE)
except Exception:
pass
t0 = time.perf_counter()
async with quic_asyncio.connect(
host_ip,
443,
configuration=cfg,
create_protocol=_H3ProbeProtocol,
) as proto:
proto._t0 = t0
proto.send_request(sni, path)
return await proto._done
async def _probe_h3(host_ip: str, sni: str, path: str, timeout: float) -> float:
if not H3_AVAILABLE:
raise RuntimeError("aioquic not installed")
# QUIC uses UDP. Wrap the ENTIRE connect+request in wait_for so a
# network that silently drops UDP packets doesn't stall indefinitely.
h3_timeout = min(timeout, 5.0)
try:
return await asyncio.wait_for(_h3_inner(host_ip, sni, path, h3_timeout), timeout=h3_timeout)
except asyncio.TimeoutError:
raise TimeoutError(f"QUIC/UDP timed out after {h3_timeout:.1f}s — UDP likely blocked or no H3 support")
except Exception as exc:
raise RuntimeError(f"{type(exc).__name__}: {exc or 'no detail'}")
# ── Runner ────────────────────────────────────────────────────────────────
async def _run_protocol(
name: str,
probe,
host_ip: str,
sni: str,
path: str,
n: int,
timeout: float,
) -> dict:
times: list[float] = []
errors = 0
for i in range(n):
try:
t = await probe(host_ip, sni, path, timeout)
times.append(t)
except Exception as exc:
errors += 1
desc = str(exc) or type(exc).__name__
print(f" [{name}] request {i+1}/{n} FAILED: {desc}")
# If the first 3 all failed, give up early to avoid wasting time.
if errors >= 3 and not times:
print(f" [{name}] 3 consecutive failures with no success — aborting protocol test")
break
await asyncio.sleep(0.05) # tiny gap between probes
return {"name": name, "times": times, "errors": errors, "n": n}
def _print_result(r: dict):
name = r["name"]
times = r["times"]
errors = r["errors"]
n = r["n"]
ok = len(times)
if not times:
print(f" {name:10s} NO SUCCESSFUL REQUESTS (errors={errors}/{n})")
return
mn = min(times) * 1000
mx = max(times) * 1000
avg = statistics.mean(times) * 1000
med = statistics.median(times) * 1000
p95 = sorted(times)[int(len(times) * 0.95)] * 1000
print(
f" {name:10s} "
f"ok={ok}/{n} "
f"min={mn:6.1f}ms "
f"avg={avg:6.1f}ms "
f"med={med:6.1f}ms "
f"p95={p95:6.1f}ms "
f"max={mx:6.1f}ms "
f"errors={errors}"
)
async def main(host_ip: str, sni: str, path: str, n: int, timeout: float,
suite: str = "all"):
print(f"\nBenchmark target → {host_ip}:443 SNI={sni} path={path}")
print("=" * 80)
run_all = suite == "all"
# ── Suite 1: Protocol sequential ──────────────────────────────────────
if run_all or suite == "protocol":
print("\n── Suite 1: Protocol sequential latency ──────────────────────────────")
print(f" {n} sequential requests per protocol\n")
protocols: list[tuple[str, object]] = [("HTTP/1.1", _probe_h1)]
if H2_AVAILABLE:
protocols.append(("HTTP/2", _probe_h2_fresh))
else:
print(" [HTTP/2] skipped — pip install h2")
if H3_AVAILABLE:
protocols.append(("HTTP/3", _probe_h3))
else:
print(" [HTTP/3] skipped — pip install aioquic")
results = []
for name, probe in protocols:
print(f" Running {name}...")
r = await _run_protocol(name, probe, host_ip, sni, path, n, timeout)
results.append(r)
print()
for r in results:
_print_result(r)
valid = [r for r in results if r["times"]]
if len(valid) > 1:
best = min(valid, key=lambda r: statistics.median(r["times"]))
print(f"\n Best median: {best['name']}")
h1r = next((r for r in valid if r["name"] == "HTTP/1.1"), None)
h2r = next((r for r in valid if r["name"] == "HTTP/2"), None)
h3r = next((r for r in valid if r["name"] == "HTTP/3"), None)
if h2r and h1r:
g = (statistics.median(h1r["times"]) - statistics.median(h2r["times"])) \
/ statistics.median(h1r["times"]) * 100
print(f" H2 vs H1.1: {g:+.1f}%")
if h3r and h2r:
g = (statistics.median(h2r["times"]) - statistics.median(h3r["times"])) \
/ statistics.median(h2r["times"]) * 100
print(f" H3 vs H2: {g:+.1f}%")
# ── Suite 2: TLS session resumption ───────────────────────────────────
if run_all or suite == "resumption":
print("\n── Suite 2: TLS session resumption ───────────────────────────────────")
print(" Measures cost of cold TLS handshake vs warm reconnect with session ticket\n")
await _suite_resumption(host_ip, sni, path, timeout, rounds=8)
# ── Suite 3: Concurrency ──────────────────────────────────────────────
if run_all or suite == "concurrency":
print("\n── Suite 3: Concurrency — H2 multiplex vs H1.1 parallel ─────────────")
print(f" {n} concurrent requests fired simultaneously\n")
await _suite_concurrency(host_ip, sni, path, timeout, n=n)
# ── Suite 4: IP scan ──────────────────────────────────────────────────
if run_all or suite == "ipscan":
print("\n── Suite 4: Google edge IP latency scan ──────────────────────────────")
print(" H1.1 probe to all candidate IPs — find the fastest one on this network\n")
await _suite_ipscan(sni, path, timeout)
print("\n" + "=" * 80)
print("Done.")
# ── Suite 2: TLS session resumption ──────────────────────────────────────
async def _tls_connect_time(host_ip: str, sni: str, timeout: float,
ctx: ssl.SSLContext | None = None) -> tuple[float, ssl.SSLContext]:
"""Connect with TLS and return (elapsed, ctx). ctx is reused for warm tests."""
if ctx is None:
ctx = _make_tls_ctx(["h2", "http/1.1"])
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
raw.setblocking(False)
loop = asyncio.get_running_loop()
t0 = time.perf_counter()
await asyncio.wait_for(loop.sock_connect(raw, (host_ip, 443)), timeout=timeout)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ssl=ctx, server_hostname=sni, sock=raw),
timeout=timeout,
)
elapsed = time.perf_counter() - t0
# Send minimal request so the server doesn't RST the idle connection
writer.write(f"GET /generate_204 HTTP/1.1\r\nHost: {sni}\r\nConnection: close\r\n\r\n".encode())
await asyncio.wait_for(writer.drain(), timeout=timeout)
try:
await asyncio.wait_for(reader.read(256), timeout=timeout)
except Exception:
pass
writer.close()
return elapsed, ctx
async def _suite_resumption(host_ip: str, sni: str, path: str,
timeout: float, rounds: int):
cold_times: list[float] = []
warm_times: list[float] = []
# cold: fresh SSLContext each time — no session ticket reuse
print(" Cold connects (new TLS context each time)...")
for _ in range(rounds):
try:
t, _ = await _tls_connect_time(host_ip, sni, timeout, ctx=None)
cold_times.append(t * 1000)
except Exception as exc:
print(f" FAILED: {exc}")
await asyncio.sleep(0.1)
# warm: reuse same SSLContext — OpenSSL caches and reuses TLS 1.3 session ticket
print(" Warm reconnects (same TLS context, session ticket reuse)...")
warm_ctx = _make_tls_ctx(["h2", "http/1.1"])
for _ in range(rounds):
try:
t, warm_ctx = await _tls_connect_time(host_ip, sni, timeout, ctx=warm_ctx)
warm_times.append(t * 1000)
except Exception as exc:
print(f" FAILED: {exc}")
await asyncio.sleep(0.1)
def _fmt(times: list[float]) -> str:
if not times:
return "no data"
return (f"min={min(times):.1f}ms avg={statistics.mean(times):.1f}ms "
f"med={statistics.median(times):.1f}ms max={max(times):.1f}ms")
print(f"\n Cold ({len(cold_times)}/{rounds} ok): {_fmt(cold_times)}")
print(f" Warm ({len(warm_times)}/{rounds} ok): {_fmt(warm_times)}")
if cold_times and warm_times:
saving = statistics.median(cold_times) - statistics.median(warm_times)
pct = saving / statistics.median(cold_times) * 100
if saving > 5:
print(f"\n Session ticket saves ~{saving:.1f}ms ({pct:.1f}%) per reconnect")
print(" → The H2 transport already reuses one long-lived connection, so this")
print(" saving only applies when the connection drops and must reconnect.")
else:
print(f"\n Resumption saving: {saving:.1f}ms ({pct:.1f}%) — negligible on this network")
print(" → Google may be issuing short-lived tickets, or RTT already dominates.")
# ── Suite 3: Concurrency ──────────────────────────────────────────────────
async def _h2_concurrent(host_ip: str, sni: str, path: str,
timeout: float, n: int) -> tuple[float, int]:
"""
Fire N H2 streams concurrently on ONE persistent connection.
Returns (wall_time_for_all, successful_count).
"""
if not H2_AVAILABLE:
raise RuntimeError("h2 not installed")
ctx = _make_tls_ctx(["h2", "http/1.1"])
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
raw.setblocking(False)
loop = asyncio.get_running_loop()
await asyncio.wait_for(loop.sock_connect(raw, (host_ip, 443)), timeout=timeout)
reader, writer = await asyncio.wait_for(
asyncio.open_connection(ssl=ctx, server_hostname=sni, sock=raw),
timeout=timeout,
)
ssl_obj = writer.get_extra_info("ssl_object")
if not ssl_obj or ssl_obj.selected_alpn_protocol() != "h2":
writer.close()
raise RuntimeError("H2 ALPN not negotiated")
cfg = h2.config.H2Configuration(client_side=True, header_encoding="utf-8")
conn = h2.connection.H2Connection(cfg)
conn.initiate_connection()
conn.increment_flow_control_window(2 ** 24 - 65535)
conn.update_settings({
h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 8 * 1024 * 1024,
h2.settings.SettingCodes.ENABLE_PUSH: 0,
})
writer.write(conn.data_to_send(65535))
await writer.drain()
# Track per-stream completion
stream_done: dict[int, asyncio.Event] = {}
stream_ids = []
for _ in range(n):
sid = conn.get_next_available_stream_id()
conn.send_headers(sid, [
(":method", "GET"), (":path", path),
(":scheme", "https"), (":authority", sni), ("accept", "*/*"),
], end_stream=True)
stream_ids.append(sid)
stream_done[sid] = asyncio.Event()
writer.write(conn.data_to_send(65535))
await writer.drain()
t0 = time.perf_counter()
done_count = 0
deadline = t0 + timeout
while done_count < n and time.perf_counter() < deadline:
try:
raw_data = await asyncio.wait_for(
reader.read(65535),
timeout=max(0.1, deadline - time.perf_counter()),
)
except asyncio.TimeoutError:
break
if not raw_data:
break
events = conn.receive_data(raw_data)
writer.write(conn.data_to_send(65535))
await writer.drain()
for ev in events:
if isinstance(ev, (h2.events.ResponseReceived, h2.events.StreamEnded)):
sid = ev.stream_id
if sid in stream_done and not stream_done[sid].is_set():
if isinstance(ev, h2.events.ResponseReceived):
stream_done[sid].set()
done_count += 1
elif isinstance(ev, h2.events.DataReceived):
conn.acknowledge_received_data(ev.flow_controlled_length, ev.stream_id)
writer.write(conn.data_to_send(65535))
await writer.drain()
wall = time.perf_counter() - t0
writer.close()
return wall, done_count
async def _h1_parallel(host_ip: str, sni: str, path: str,
timeout: float, n: int) -> tuple[float, int]:
"""Fire N H1.1 requests in parallel, each on its own TCP+TLS connection."""
t0 = time.perf_counter()
tasks = [asyncio.create_task(_probe_h1(host_ip, sni, path, timeout)) for _ in range(n)]
results = await asyncio.gather(*tasks, return_exceptions=True)
wall = time.perf_counter() - t0
ok = sum(1 for r in results if isinstance(r, float))
return wall, ok
async def _suite_concurrency(host_ip: str, sni: str, path: str,
timeout: float, n: int):
concur_levels = sorted({4, 8, min(16, n), min(n, 20)})
print(f" {'Level':>5} {'H2 mux wall':>14} {'H1.1 parallel wall':>18} {'speedup':>8}")
print(f" {'-----':>5} {'----------':>14} {'----------------':>18} {'-------':>8}")
for level in concur_levels:
h2_wall = h2_ok = h1_wall = h1_ok = None
h2_err = h1_err = None
if H2_AVAILABLE:
try:
h2_wall, h2_ok = await _h2_concurrent(host_ip, sni, path, timeout, level)
except Exception as exc:
h2_err = str(exc) or type(exc).__name__
try:
h1_wall, h1_ok = await _h1_parallel(host_ip, sni, path, timeout, level)
except Exception as exc:
h1_err = str(exc) or type(exc).__name__
h2_str = f"{h2_wall*1000:6.0f}ms ({h2_ok}/{level})" if h2_wall is not None else f"FAIL: {h2_err}"
h1_str = f"{h1_wall*1000:6.0f}ms ({h1_ok}/{level})" if h1_wall is not None else f"FAIL: {h1_err}"
if h2_wall and h1_wall and h1_wall > 0:
speedup = f"{h1_wall / h2_wall:+.2f}x"
else:
speedup = "n/a"
print(f" {level:>5} {h2_str:>14} {h1_str:>18} {speedup:>8}")
await asyncio.sleep(0.2)
print()
print(" Interpretation:")
print(" - H2 mux fires all streams on ONE TLS connection — lower overhead at scale")
print(" - H1.1 parallel opens N separate connections — higher per-connection TLS cost")
print(" - Speedup > 1.0x means H2 mux completed all requests in less wall time")
# ── Suite 4: IP scan ──────────────────────────────────────────────────────
_CANDIDATE_IPS = (
"216.239.32.120", "216.239.34.120", "216.239.36.120", "216.239.38.120",
"142.250.80.142", "142.250.80.138", "142.250.179.110", "142.250.185.110",
"142.250.184.206", "142.250.190.238", "142.250.191.78",
"172.217.1.206", "172.217.14.206", "172.217.16.142", "172.217.22.174",
"172.217.164.110","172.217.168.206","172.217.169.206",
"34.107.221.82",
"142.251.32.110", "142.251.33.110", "142.251.46.206", "142.251.46.238",
"142.250.80.170", "142.250.72.206", "142.250.64.206", "142.250.72.110",
)
async def _probe_ip(ip: str, sni: str, path: str, timeout: float) -> tuple[str, float | None, str]:
"""Return (ip, median_ms_or_None, note)."""
times = []
for _ in range(3):
try:
t = await _probe_h1(ip, sni, path, timeout)
times.append(t * 1000)
except Exception:
pass
await asyncio.sleep(0.03)
if not times:
return ip, None, "unreachable"
med = statistics.median(times)
return ip, med, ""
async def _suite_ipscan(sni: str, path: str, timeout: float):
ip_timeout = min(timeout, 5.0)
print(f" Probing {len(_CANDIDATE_IPS)} candidate IPs (3 requests each, {ip_timeout:.0f}s cap)...\n")
# Run all probes concurrently — they're independent H1.1 connects
tasks = [asyncio.create_task(_probe_ip(ip, sni, path, ip_timeout))
for ip in _CANDIDATE_IPS]
raw_results = await asyncio.gather(*tasks)
reachable = [(ip, med, note) for ip, med, note in raw_results if med is not None]
dead = [(ip, med, note) for ip, med, note in raw_results if med is None]
reachable.sort(key=lambda x: x[1])
print(f" {'IP':>18} {'median':>9} note")
print(f" {'--':>18} {'------':>9} ----")
for i, (ip, med, _) in enumerate(reachable):
tag = " ← fastest" if i == 0 else (" ← 2nd" if i == 1 else "")
print(f" {ip:>18} {med:7.1f}ms{tag}")
if dead:
print(f"\n Unreachable ({len(dead)}): {', '.join(ip for ip, *_ in dead)}")
if reachable:
best_ip, best_med, _ = reachable[0]
print(f"\n Fastest IP: {best_ip} (median {best_med:.1f}ms)")
print(f' Set in config.json: "google_ip": "{best_ip}"')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Transport benchmark suite")
parser.add_argument("--ip", help="Google edge IP (default: from config.json)")
parser.add_argument("--sni", default="www.google.com", help="SNI hostname")
parser.add_argument("--path", default="/generate_204", help="Request path")
parser.add_argument("--n", type=int, default=15, help="Requests per protocol")
parser.add_argument("--timeout", type=float, default=10.0, help="Per-request timeout (s)")
parser.add_argument(
"--suite",
choices=["all", "protocol", "resumption", "concurrency", "ipscan"],
default="all",
help="Which benchmark suite to run (default: all)",
)
args = parser.parse_args()
host_ip = args.ip
if not host_ip:
cfg_path = Path(__file__).parent.parent / "config.json"
if cfg_path.exists():
with open(cfg_path) as f:
data = json.load(f)
host_ip = data.get("google_ip", "216.239.38.120")
print(f"Using google_ip from config.json: {host_ip}")
else:
host_ip = "216.239.38.120"
print(f"config.json not found, using default: {host_ip}")
asyncio.run(main(
host_ip=host_ip,
sni=args.sni,
path=args.path,
n=args.n,
timeout=args.timeout,
suite=args.suite,
))
-69
View File
@@ -1,69 +0,0 @@
#!/usr/bin/env python3
from __future__ import annotations
import hashlib
import os
import re
import shutil
import tarfile
import zipfile
from pathlib import Path
def _read_version(root: Path) -> str:
constants_py = (root / "src" / "constants.py").read_text(encoding="utf-8")
m = re.search(r'__version__\s*=\s*"([^"]+)"', constants_py)
return m.group(1) if m else "0.0.0"
def main() -> int:
root = Path(".").resolve()
target = os.environ.get("TARGET", "")
if not target:
raise SystemExit("TARGET environment variable is required")
version = _read_version(root)
binary_name = "MasterHttpRelayVPN.exe" if os.name == "nt" else "MasterHttpRelayVPN"
binary_path = root / "dist" / binary_name
if not binary_path.exists():
raise SystemExit(f"binary not found: {binary_path}")
bundle_name = f"MasterHttpRelayVPN-{version}-{target}"
bundle_root = root / "package" / bundle_name
if bundle_root.exists():
shutil.rmtree(bundle_root)
bundle_root.mkdir(parents=True, exist_ok=True)
config_example = root / "config.example.json"
if not config_example.exists():
raise SystemExit(f"missing config.example.json: {config_example}")
shutil.copy2(config_example, bundle_root / "config.example.json")
shutil.copy2(binary_path, bundle_root / binary_name)
if os.name != "nt":
(bundle_root / binary_name).chmod(0o755)
release_dir = root / "release-assets"
release_dir.mkdir(parents=True, exist_ok=True)
if target.startswith("windows"):
archive = release_dir / f"{bundle_name}.zip"
with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED) as zf:
for path in bundle_root.rglob("*"):
zf.write(path, path.relative_to(bundle_root.parent))
else:
archive = release_dir / f"{bundle_name}.tar.gz"
with tarfile.open(archive, "w:gz") as tf:
tf.add(bundle_root, arcname=bundle_name)
digest = hashlib.sha256(archive.read_bytes()).hexdigest()
(release_dir / f"{archive.name}.sha256").write_text(
f"{digest} {archive.name}\n",
encoding="utf-8",
)
print(f"Created {archive}")
return 0
if __name__ == "__main__":
raise SystemExit(main())
+6
View File
@@ -0,0 +1,6 @@
from .constants import *
from .codec import *
from .logging_utils import *
from .cert_installer import *
from .lan_utils import *
from .google_ip_scanner import *
@@ -5,7 +5,7 @@ Supports: Windows, macOS, Linux (Debian/Ubuntu, RHEL/Fedora/CentOS, Arch).
Also attempts to install into Firefox's NSS certificate store when found. Also attempts to install into Firefox's NSS certificate store when found.
Usage: Usage:
from cert_installer import install_ca, is_ca_trusted from core.cert_installer import install_ca, is_ca_trusted
install_ca("/path/to/ca.crt", cert_name="MasterHttpRelayVPN") install_ca("/path/to/ca.crt", cert_name="MasterHttpRelayVPN")
""" """
+3
View File
@@ -30,6 +30,9 @@ except ImportError: # pragma: no cover
_ZSTD_DCTX = None _ZSTD_DCTX = None
__all__ = ["supported_encodings", "has_brotli", "has_zstd", "decode"]
def supported_encodings() -> str: def supported_encodings() -> str:
"""Value for Accept-Encoding that this relay can actually decode.""" """Value for Accept-Encoding that this relay can actually decode."""
codecs = ["gzip", "deflate"] codecs = ["gzip", "deflate"]
@@ -15,7 +15,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
from constants import CANDIDATE_IPS, GOOGLE_SCANNER_TIMEOUT, GOOGLE_SCANNER_CONCURRENCY from .constants import CANDIDATE_IPS, GOOGLE_SCANNER_TIMEOUT, GOOGLE_SCANNER_CONCURRENCY
log = logging.getLogger("Scanner") log = logging.getLogger("Scanner")
@@ -227,44 +227,43 @@ def _install_asyncio_noise_filter() -> None:
def print_banner(version: str, *, stream=None) -> None: def print_banner(version: str, *, stream=None) -> None:
"""Print a polished startup banner with color fallbacks.""" """Print an ASCII startup banner with color fallbacks."""
stream = stream or sys.stderr stream = stream or sys.stderr
color = _supports_color(stream) color = _supports_color(stream)
def c(code: str) -> str: def c(code: str) -> str:
return code if color else "" return code if color else ""
title = "MasterHttpRelayVPN" art = [
subtitle = "Domain-Fronted Apps Script Relay" " __ __ _ ____ _____ _____ ____ ",
version_tag = f"v{version}" "| \\/ | / \\ / ___|_ _| ____| _ \\ ",
"| |\\/| | / _ \\ \\___ \\ | | | _| | |_) |",
left = f" {title} " "| | | |/ ___ \\ ___) || | | |___| _ < ",
center = f" {subtitle} " "|_| |_/_/ \\_\\____/ |_| |_____|_| \\_\\",
right = f" {version_tag} " " _ _ _____ _____ ____ ____ _____ _ _ __ __",
inner_width = max(68, len(left) + len(center) + len(right) + 2) " | | | |_ _|_ _| _ \\ | _ \\| ____| | / \\\\ \\ / /",
" | |_| | | | | | | |_) | | |_) | _| | | / _ \\\\ V / ",
gap = inner_width - (len(left) + len(center) + len(right)) " | _ | | | | | | __/ | _ <| |___| |___ / ___ \\| | ",
left_gap = gap // 2 " |_| |_| |_| |_| |_| |_| \\_\\_____|_____/_/ \\_\\_| ",
right_gap = gap - left_gap ]
version_line = f"Version {version}"
top = "" + ("" * inner_width) + "" link = "https://github.com/masterking32/MasterHttpRelayVPN"
mid = "" + left + (" " * left_gap) + center + (" " * right_gap) + right + "" width = max(max(len(line) for line in art), len(version_line), len(link))
bot = "" + ("" * inner_width) + "" rule = "=" * width
if color: if color:
top = f"{DIM}{FG_GRAY}{top}{RESET}" print(f"{DIM}{FG_GRAY}{rule}{RESET}", file=stream)
bot = f"{DIM}{FG_GRAY}{bot}{RESET}" for line in art:
mid = ( print(f"{BOLD}{FG_CYAN}{line.center(width)}{RESET}", file=stream)
f"{DIM}{FG_GRAY}{RESET}" print(f"{FG_GRAY}{version_line.center(width)}{RESET}", file=stream)
f"{BOLD}{FG_CYAN}{left}{RESET}" print(f"{FG_TEAL}{link.center(width)}{RESET}", file=stream)
f"{' ' * left_gap}" print(f"{DIM}{FG_GRAY}{rule}{RESET}", file=stream)
f"{FG_GRAY}{center}{RESET}" else:
f"{' ' * right_gap}" print(rule, file=stream)
f"{BOLD}{FG_TEAL}{right}{RESET}" for line in art:
f"{DIM}{FG_GRAY}{RESET}" print(line.center(width), file=stream)
) print(version_line.center(width), file=stream)
print(link.center(width), file=stream)
print(rule, file=stream)
print(top, file=stream)
print(mid, file=stream)
print(bot, file=stream)
stream.flush() stream.flush()
+5
View File
@@ -0,0 +1,5 @@
from .proxy_server import ProxyServer
from .proxy_support import *
from .socks5 import *
from .mitm import *
__all__ = ["ProxyServer"]
+3 -3
View File
@@ -24,10 +24,10 @@ from cryptography.x509.oid import NameOID
log = logging.getLogger("MITM") log = logging.getLogger("MITM")
# CA lives at the project root (../ca/ relative to this file in src/). # Keep the CA at repository root so docs/installer paths stay stable.
# The installed trusted root was generated there; keep using it.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) _THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_PROJECT_ROOT = os.path.dirname(_THIS_DIR) _SRC_DIR = os.path.dirname(_THIS_DIR)
_PROJECT_ROOT = os.path.dirname(_SRC_DIR)
CA_DIR = os.path.join(_PROJECT_ROOT, "ca") CA_DIR = os.path.join(_PROJECT_ROOT, "ca")
CA_KEY_FILE = os.path.join(CA_DIR, "ca.key") CA_KEY_FILE = os.path.join(CA_DIR, "ca.key")
CA_CERT_FILE = os.path.join(CA_DIR, "ca.crt") CA_CERT_FILE = os.path.join(CA_DIR, "ca.crt")
+58 -359
View File
@@ -13,18 +13,14 @@ import socket
import ssl import ssl
import time import time
import ipaddress import ipaddress
from urllib.parse import urlparse
try: try:
import certifi import certifi
except Exception: # optional dependency fallback except Exception: # optional dependency fallback
certifi = None certifi = None
from constants import ( from core.constants import (
CACHE_MAX_MB, CACHE_MAX_MB,
CACHE_TTL_MAX,
CACHE_TTL_STATIC_LONG,
CACHE_TTL_STATIC_MED,
CLIENT_IDLE_TIMEOUT, CLIENT_IDLE_TIMEOUT,
GOOGLE_DIRECT_ALLOW_EXACT, GOOGLE_DIRECT_ALLOW_EXACT,
GOOGLE_DIRECT_ALLOW_SUFFIXES, GOOGLE_DIRECT_ALLOW_SUFFIXES,
@@ -36,132 +32,29 @@ from constants import (
MAX_HEADER_BYTES, MAX_HEADER_BYTES,
MAX_REQUEST_BODY_BYTES, MAX_REQUEST_BODY_BYTES,
SNI_REWRITE_SUFFIXES, SNI_REWRITE_SUFFIXES,
STATIC_EXTS,
TCP_CONNECT_TIMEOUT, TCP_CONNECT_TIMEOUT,
TRACE_HOST_SUFFIXES, TRACE_HOST_SUFFIXES,
UNCACHEABLE_HEADER_NAMES, UNCACHEABLE_HEADER_NAMES,
) )
from domain_fronter import DomainFronter from relay.domain_fronter import DomainFronter
from .socks5 import negotiate_socks5
from .proxy_support import (
ResponseCache,
cors_preflight_response,
has_unsupported_transfer_encoding,
header_value,
host_matches_rules,
inject_cors_headers,
is_ip_literal,
load_host_rules,
log_response_summary,
parse_content_length,
)
from relay.relay_response import split_raw_response
log = logging.getLogger("Proxy") log = logging.getLogger("Proxy")
def _is_ip_literal(host: str) -> bool:
"""True for IPv4/IPv6 literals (strips brackets around IPv6)."""
h = host.strip("[]")
try:
ipaddress.ip_address(h)
return True
except ValueError:
return False
def _parse_content_length(header_block: bytes) -> int:
"""Return Content-Length or 0. Matches only the exact header name."""
for raw_line in header_block.split(b"\r\n"):
name, sep, value = raw_line.partition(b":")
if not sep:
continue
if name.strip().lower() == b"content-length":
try:
return int(value.strip())
except ValueError:
return 0
return 0
def _has_unsupported_transfer_encoding(header_block: bytes) -> bool:
"""True when the request uses Transfer-Encoding, which we don't stream."""
for raw_line in header_block.split(b"\r\n"):
name, sep, value = raw_line.partition(b":")
if not sep:
continue
if name.strip().lower() != b"transfer-encoding":
continue
encodings = [
token.strip().lower()
for token in value.decode(errors="replace").split(",")
if token.strip()
]
return any(token != "identity" for token in encodings)
return False
class ResponseCache:
"""Simple LRU response cache — avoids repeated relay calls."""
def __init__(self, max_mb: int = 50):
self._store: dict[str, tuple[bytes, float]] = {}
self._size = 0
self._max = max_mb * 1024 * 1024
self.hits = 0
self.misses = 0
def get(self, url: str) -> bytes | None:
entry = self._store.get(url)
if not entry:
self.misses += 1
return None
raw, expires = entry
if time.time() > expires:
self._size -= len(raw)
del self._store[url]
self.misses += 1
return None
self.hits += 1
return raw
def put(self, url: str, raw_response: bytes, ttl: int = 300):
size = len(raw_response)
if size > self._max // 4 or size == 0:
return
# Evict oldest to make room
while self._size + size > self._max and self._store:
oldest = next(iter(self._store))
self._size -= len(self._store[oldest][0])
del self._store[oldest]
if url in self._store:
self._size -= len(self._store[url][0])
self._store[url] = (raw_response, time.time() + ttl)
self._size += size
@staticmethod
def parse_ttl(raw_response: bytes, url: str) -> int:
"""Determine cache TTL from response headers and URL."""
hdr_end = raw_response.find(b"\r\n\r\n")
if hdr_end < 0:
return 0
hdr = raw_response[:hdr_end].decode(errors="replace").lower()
# Don't cache errors or non-200
if b"HTTP/1.1 200" not in raw_response[:20]:
return 0
if "no-store" in hdr or "private" in hdr or "set-cookie:" in hdr:
return 0
# Explicit max-age
m = re.search(r"max-age=(\d+)", hdr)
if m:
return min(int(m.group(1)), CACHE_TTL_MAX)
# Heuristic by content type / extension
path = url.split("?")[0].lower()
for ext in STATIC_EXTS:
if path.endswith(ext):
return CACHE_TTL_STATIC_LONG
ct_m = re.search(r"content-type:\s*([^\r\n]+)", hdr)
ct = ct_m.group(1) if ct_m else ""
if "image/" in ct or "font/" in ct:
return CACHE_TTL_STATIC_LONG
if "text/css" in ct or "javascript" in ct:
return CACHE_TTL_STATIC_MED
if "text/html" in ct or "application/json" in ct:
return 0 # don't cache dynamic content by default
return 0
class ProxyServer: class ProxyServer:
# Pulled from constants.py so users can override any subset via config. # Pulled from constants.py so users can override any subset via config.
_GOOGLE_DIRECT_EXACT_EXCLUDE = GOOGLE_DIRECT_EXACT_EXCLUDE _GOOGLE_DIRECT_EXACT_EXCLUDE = GOOGLE_DIRECT_EXACT_EXCLUDE
@@ -246,8 +139,8 @@ class ProxyServer:
# bypass_hosts — route directly (no MITM, no relay) # bypass_hosts — route directly (no MITM, no relay)
# Both accept exact hostnames and leading-dot suffix patterns, # Both accept exact hostnames and leading-dot suffix patterns,
# e.g. ".local" matches any *.local domain. # e.g. ".local" matches any *.local domain.
self._block_hosts = self._load_host_rules(config.get("block_hosts", [])) self._block_hosts = load_host_rules(config.get("block_hosts", []))
self._bypass_hosts = self._load_host_rules(config.get("bypass_hosts", [])) self._bypass_hosts = load_host_rules(config.get("bypass_hosts", []))
# Route YouTube through the relay when requested; the Google frontend # Route YouTube through the relay when requested; the Google frontend
# IP can enforce SafeSearch on the SNI-rewrite path. # IP can enforce SafeSearch on the SNI-rewrite path.
@@ -261,7 +154,7 @@ class ProxyServer:
self._SNI_REWRITE_SUFFIXES = SNI_REWRITE_SUFFIXES self._SNI_REWRITE_SUFFIXES = SNI_REWRITE_SUFFIXES
try: try:
from mitm import MITMCertManager from .mitm import MITMCertManager
self.mitm = MITMCertManager() self.mitm = MITMCertManager()
except ImportError: except ImportError:
log.error("Apps Script relay requires the 'cryptography' package.") log.error("Apps Script relay requires the 'cryptography' package.")
@@ -319,135 +212,21 @@ class ProxyServer:
if task is not None: if task is not None:
self._client_tasks.discard(task) self._client_tasks.discard(task)
@staticmethod
def _load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]:
"""Accept a list of host strings; return (exact_set, suffix_tuple).
A rule starting with '.' (e.g. ".internal") is a suffix rule.
Everything else is treated as an exact match. Case-insensitive.
"""
exact: set[str] = set()
suffixes: list[str] = []
for item in raw or []:
h = str(item).strip().lower().rstrip(".")
if not h:
continue
if h.startswith("."):
suffixes.append(h)
else:
exact.add(h)
return exact, tuple(suffixes)
@staticmethod
def _host_matches_rules(host: str,
rules: tuple[set[str], tuple[str, ...]]) -> bool:
exact, suffixes = rules
h = host.lower().rstrip(".")
if h in exact:
return True
for s in suffixes:
if h.endswith(s):
return True
return False
def _is_blocked(self, host: str) -> bool: def _is_blocked(self, host: str) -> bool:
return self._host_matches_rules(host, self._block_hosts) return host_matches_rules(host, self._block_hosts)
def _is_bypassed(self, host: str) -> bool: def _is_bypassed(self, host: str) -> bool:
return self._host_matches_rules(host, self._bypass_hosts) return host_matches_rules(host, self._bypass_hosts)
@staticmethod
def _header_value(headers: dict | None, name: str) -> str:
if not headers:
return ""
for key, value in headers.items():
if key.lower() == name:
return str(value)
return ""
def _cache_allowed(self, method: str, url: str, def _cache_allowed(self, method: str, url: str,
headers: dict | None, body: bytes) -> bool: headers: dict | None, body: bytes) -> bool:
if method.upper() != "GET" or body: if method.upper() != "GET" or body:
return False return False
for name in UNCACHEABLE_HEADER_NAMES: for name in UNCACHEABLE_HEADER_NAMES:
if self._header_value(headers, name): if header_value(headers, name):
return False return False
return self.fronter._is_static_asset_url(url) return self.fronter._is_static_asset_url(url)
@classmethod
def _should_trace_host(cls, host: str) -> bool:
h = host.lower().rstrip(".")
return any(
token == h or token in h or h.endswith("." + token)
for token in cls._TRACE_HOST_SUFFIXES
)
def _log_response_summary(self, url: str, response: bytes):
status, headers, body = self.fronter._split_raw_response(response)
host = (urlparse(url).hostname or "").lower()
if status >= 300 or self._should_trace_host(host):
location = headers.get("location", "") or "-"
server = headers.get("server", "") or "-"
cf_ray = headers.get("cf-ray", "") or "-"
content_type = headers.get("content-type", "") or "-"
body_len = len(body)
body_hint = "-"
rate_limited = False
# Handle text-like responses (HTML, plain text, JSON…)
if ("text" in content_type.lower() or "json" in content_type.lower()) and body:
sample = body[:1200].decode(errors="replace").lower()
# --- Structured HTML title extraction ---
if "<title>" in sample and "</title>" in sample:
title = sample.split("<title>", 1)[1].split("</title>", 1)[0]
body_hint = title.strip()[:120] or "-"
# --- Known content patterns ---
elif "captcha" in sample:
body_hint = "captcha"
elif "turnstile" in sample:
body_hint = "turnstile"
elif "loading" in sample:
body_hint = "loading"
# --- Rate-limit / quota markers ---
rate_limit_markers = (
"too many",
"rate limit",
"quota",
"quota exceeded",
"request limit",
"دفعات زیاد",
"بیش از حد",
"سرویس در طول یک روز",
)
if any(m in sample for m in rate_limit_markers):
rate_limited = True
body_hint = "quota_exceeded"
log_msg = (
"RESP ← %s status=%s type=%s len=%s server=%s location=%s cf-ray=%s hint=%s"
)
log_args = (
host or url[:60],
status,
content_type,
body_len,
server,
location,
cf_ray,
body_hint,
)
if rate_limited:
log.warning("RATE LIMIT detected! " + log_msg, *log_args)
else:
log.info(log_msg, *log_args)
async def start(self): async def start(self):
http_srv = await asyncio.start_server(self._on_client, self.host, self.port) http_srv = await asyncio.start_server(self._on_client, self.host, self.port)
socks_srv = None socks_srv = None
@@ -534,7 +313,7 @@ class ProxyServer:
if line in (b"\r\n", b"\n", b""): if line in (b"\r\n", b"\n", b""):
break break
if _has_unsupported_transfer_encoding(header_block): if has_unsupported_transfer_encoding(header_block):
log.warning("Unsupported Transfer-Encoding on client request") log.warning("Unsupported Transfer-Encoding on client request")
writer.write( writer.write(
b"HTTP/1.1 501 Not Implemented\r\n" b"HTTP/1.1 501 Not Implemented\r\n"
@@ -575,52 +354,12 @@ class ProxyServer:
addr = writer.get_extra_info("peername") addr = writer.get_extra_info("peername")
task = self._track_current_task() task = self._track_current_task()
try: try:
header = await asyncio.wait_for(reader.readexactly(2), timeout=15) result = await negotiate_socks5(reader, writer)
ver, nmethods = header[0], header[1] if result is None:
if ver != 5:
return return
host, port = result
methods = await asyncio.wait_for(reader.readexactly(nmethods), timeout=10)
if 0x00 not in methods:
writer.write(b"\x05\xff")
await writer.drain()
return
writer.write(b"\x05\x00")
await writer.drain()
req = await asyncio.wait_for(reader.readexactly(4), timeout=15)
ver, cmd, _rsv, atyp = req
if ver != 5 or cmd != 0x01:
writer.write(b"\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00")
await writer.drain()
return
if atyp == 0x01:
raw = await asyncio.wait_for(reader.readexactly(4), timeout=10)
host = socket.inet_ntoa(raw)
elif atyp == 0x03:
ln = (await asyncio.wait_for(reader.readexactly(1), timeout=10))[0]
host = (await asyncio.wait_for(reader.readexactly(ln), timeout=10)).decode(
errors="replace"
)
elif atyp == 0x04:
raw = await asyncio.wait_for(reader.readexactly(16), timeout=10)
host = socket.inet_ntop(socket.AF_INET6, raw)
else:
writer.write(b"\x05\x08\x00\x01\x00\x00\x00\x00\x00\x00")
await writer.drain()
return
port_raw = await asyncio.wait_for(reader.readexactly(2), timeout=10)
port = int.from_bytes(port_raw, "big")
log.info("SOCKS5 CONNECT → %s:%d", host, port) log.info("SOCKS5 CONNECT → %s:%d", host, port)
writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
await writer.drain()
await self._handle_target_tunnel(host, port, reader, writer) await self._handle_target_tunnel(host, port, reader, writer)
except asyncio.IncompleteReadError: except asyncio.IncompleteReadError:
pass pass
except asyncio.CancelledError: except asyncio.CancelledError:
@@ -689,7 +428,7 @@ class ProxyServer:
# clients like Telegram speed up DC-rotation when we fail fast. # clients like Telegram speed up DC-rotation when we fail fast.
# We remember per-IP failures for a short while so subsequent # We remember per-IP failures for a short while so subsequent
# connects skip the doomed direct attempt. # connects skip the doomed direct attempt.
if _is_ip_literal(host): if is_ip_literal(host):
if not self._direct_temporarily_disabled(host): if not self._direct_temporarily_disabled(host):
log.info("Direct tunnel → %s:%d (IP literal)", host, port) log.info("Direct tunnel → %s:%d (IP literal)", host, port)
ok = await self._do_direct_tunnel( ok = await self._do_direct_tunnel(
@@ -1096,7 +835,7 @@ class ProxyServer:
# either. Telegram will rotate to another DC on its own; # either. Telegram will rotate to another DC on its own;
# failing fast here lets that happen sooner. # failing fast here lets that happen sooner.
# • Client CONNECTs but never speaks TLS (some probes). # • Client CONNECTs but never speaks TLS (some probes).
if _is_ip_literal(host) and port == 443: if is_ip_literal(host) and port == 443:
log.info( log.info(
"Non-TLS traffic on %s:%d (likely Telegram MTProto / " "Non-TLS traffic on %s:%d (likely Telegram MTProto / "
"obfuscated protocol). This DC appears blocked; the " "obfuscated protocol). This DC appears blocked; the "
@@ -1168,7 +907,7 @@ class ProxyServer:
# Read body # Read body
body = b"" body = b""
if _has_unsupported_transfer_encoding(header_block): if has_unsupported_transfer_encoding(header_block):
log.warning("Unsupported Transfer-Encoding → %s:%d", host, port) log.warning("Unsupported Transfer-Encoding → %s:%d", host, port)
writer.write( writer.write(
b"HTTP/1.1 501 Not Implemented\r\n" b"HTTP/1.1 501 Not Implemented\r\n"
@@ -1177,7 +916,7 @@ class ProxyServer:
) )
await writer.drain() await writer.drain()
break break
length = _parse_content_length(header_block) length = parse_content_length(header_block)
if length > MAX_REQUEST_BODY_BYTES: if length > MAX_REQUEST_BODY_BYTES:
raise ValueError(f"Request body too large: {length} bytes") raise ValueError(f"Request body too large: {length} bytes")
if length > 0: if length > 0:
@@ -1217,11 +956,11 @@ class ProxyServer:
log.info("MITM → %s %s", method, url) log.info("MITM → %s %s", method, url)
# ── CORS: extract relevant request headers ───────────── # ── CORS: extract relevant request headers ─────────────
origin = self._header_value(headers, "origin") origin = header_value(headers, "origin")
acr_method = self._header_value( acr_method = header_value(
headers, "access-control-request-method", headers, "access-control-request-method",
) )
acr_headers = self._header_value( acr_headers = header_value(
headers, "access-control-request-headers", headers, "access-control-request-headers",
) )
@@ -1234,7 +973,7 @@ class ProxyServer:
"CORS preflight → %s (responding locally)", "CORS preflight → %s (responding locally)",
url[:60], url[:60],
) )
writer.write(self._cors_preflight_response( writer.write(cors_preflight_response(
origin, acr_method, acr_headers, origin, acr_method, acr_headers,
)) ))
await writer.drain() await writer.drain()
@@ -1276,9 +1015,15 @@ class ProxyServer:
# browser blocks the response even though the relay fetched # browser blocks the response even though the relay fetched
# it successfully. # it successfully.
if origin and response: if origin and response:
response = self._inject_cors_headers(response, origin) response = inject_cors_headers(response, origin)
self._log_response_summary(url, response) log_response_summary(
logger=log,
split_raw_response=split_raw_response,
trace_suffixes=self._TRACE_HOST_SUFFIXES,
url=url,
response=response,
)
writer.write(response) writer.write(response)
await writer.drain() await writer.drain()
@@ -1294,59 +1039,7 @@ class ProxyServer:
break break
# ── CORS helpers ────────────────────────────────────────────── # ── CORS helpers ──────────────────────────────────────────────
# cors_preflight_response() and inject_cors_headers() live in proxy_support.
@staticmethod
def _cors_preflight_response(origin: str, acr_method: str,
acr_headers: str) -> bytes:
"""Build a 204 response that satisfies a CORS preflight locally.
Apps Script's UrlFetchApp does not support OPTIONS, so we have to
answer preflights here instead of forwarding them.
"""
allow_origin = origin or "*"
allow_methods = (
f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS"
if acr_method else
"GET, POST, PUT, DELETE, PATCH, OPTIONS"
)
allow_headers = acr_headers or "*"
return (
"HTTP/1.1 204 No Content\r\n"
f"Access-Control-Allow-Origin: {allow_origin}\r\n"
f"Access-Control-Allow-Methods: {allow_methods}\r\n"
f"Access-Control-Allow-Headers: {allow_headers}\r\n"
"Access-Control-Allow-Credentials: true\r\n"
"Access-Control-Max-Age: 86400\r\n"
"Vary: Origin\r\n"
"Content-Length: 0\r\n"
"\r\n"
).encode()
@staticmethod
def _inject_cors_headers(response: bytes, origin: str) -> bytes:
"""Strip existing Access-Control-* headers and add permissive ones.
Keeps the body untouched; only rewrites the header block. Using
the exact browser-supplied Origin (rather than "*") is required
when the request is credentialed (cookies, Authorization).
"""
sep = b"\r\n\r\n"
if sep not in response:
return response
header_section, body = response.split(sep, 1)
lines = header_section.decode(errors="replace").split("\r\n")
lines = [ln for ln in lines
if not ln.lower().startswith("access-control-")]
allow_origin = origin or "*"
lines += [
f"Access-Control-Allow-Origin: {allow_origin}",
"Access-Control-Allow-Credentials: true",
"Access-Control-Allow-Methods: GET, POST, PUT, DELETE, PATCH, OPTIONS",
"Access-Control-Allow-Headers: *",
"Access-Control-Expose-Headers: *",
"Vary: Origin",
]
return ("\r\n".join(lines) + "\r\n\r\n").encode() + body
async def _relay_smart(self, method, url, headers, body): async def _relay_smart(self, method, url, headers, body):
"""Choose optimal relay strategy based on request type. """Choose optimal relay strategy based on request type.
@@ -1388,7 +1081,7 @@ class ProxyServer:
for ext in self._download_extensions: for ext in self._download_extensions:
if path.endswith(ext): if path.endswith(ext):
return True return True
accept = self._header_value(headers, "accept").lower() accept = header_value(headers, "accept").lower()
if any(marker in accept for marker in self._DOWNLOAD_ACCEPT_MARKERS): if any(marker in accept for marker in self._DOWNLOAD_ACCEPT_MARKERS):
return True return True
return False return False
@@ -1421,7 +1114,7 @@ class ProxyServer:
async def _do_http(self, header_block: bytes, reader, writer): async def _do_http(self, header_block: bytes, reader, writer):
body = b"" body = b""
if _has_unsupported_transfer_encoding(header_block): if has_unsupported_transfer_encoding(header_block):
log.warning("Unsupported Transfer-Encoding on plain HTTP request") log.warning("Unsupported Transfer-Encoding on plain HTTP request")
writer.write( writer.write(
b"HTTP/1.1 501 Not Implemented\r\n" b"HTTP/1.1 501 Not Implemented\r\n"
@@ -1430,7 +1123,7 @@ class ProxyServer:
) )
await writer.drain() await writer.drain()
return return
length = _parse_content_length(header_block) length = parse_content_length(header_block)
if length > MAX_REQUEST_BODY_BYTES: if length > MAX_REQUEST_BODY_BYTES:
writer.write(b"HTTP/1.1 413 Content Too Large\r\n\r\n") writer.write(b"HTTP/1.1 413 Content Too Large\r\n\r\n")
await writer.drain() await writer.drain()
@@ -1453,12 +1146,12 @@ class ProxyServer:
headers[k.strip()] = v.strip() headers[k.strip()] = v.strip()
# ── CORS preflight over plain HTTP ───────────────────────────── # ── CORS preflight over plain HTTP ─────────────────────────────
origin = self._header_value(headers, "origin") origin = header_value(headers, "origin")
acr_method = self._header_value(headers, "access-control-request-method") acr_method = header_value(headers, "access-control-request-method")
acr_headers = self._header_value(headers, "access-control-request-headers") acr_headers = header_value(headers, "access-control-request-headers")
if method.upper() == "OPTIONS" and acr_method: if method.upper() == "OPTIONS" and acr_method:
log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60]) log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60])
writer.write(self._cors_preflight_response( writer.write(cors_preflight_response(
origin, acr_method, acr_headers, origin, acr_method, acr_headers,
)) ))
await writer.drain() await writer.drain()
@@ -1483,9 +1176,15 @@ class ProxyServer:
self._cache.put(url, response, ttl) self._cache.put(url, response, ttl)
if origin and response: if origin and response:
response = self._inject_cors_headers(response, origin) response = inject_cors_headers(response, origin)
self._log_response_summary(url, response) log_response_summary(
logger=log,
split_raw_response=split_raw_response,
trace_suffixes=self._TRACE_HOST_SUFFIXES,
url=url,
response=response,
)
writer.write(response) writer.write(response)
await writer.drain() await writer.drain()
+307
View File
@@ -0,0 +1,307 @@
"""
Proxy helper utilities: header parsing, host rule matching, response caching,
CORS injection, and response logging.
Extracted from proxy_server.py to separate pure helper logic from the
ProxyServer connection handler.
"""
import ipaddress
import logging
import re
import time
from urllib.parse import urlparse
from core.constants import (
CACHE_TTL_MAX,
CACHE_TTL_STATIC_LONG,
CACHE_TTL_STATIC_MED,
STATIC_EXTS,
)
__all__ = [
"is_ip_literal",
"parse_content_length",
"has_unsupported_transfer_encoding",
"load_host_rules",
"host_matches_rules",
"header_value",
"should_trace_host",
"log_response_summary",
"ResponseCache",
"cors_preflight_response",
"inject_cors_headers",
]
def is_ip_literal(host: str) -> bool:
"""True for IPv4/IPv6 literals (strips brackets around IPv6)."""
normalized = host.strip("[]")
try:
ipaddress.ip_address(normalized)
return True
except ValueError:
return False
def parse_content_length(header_block: bytes) -> int:
"""Return Content-Length or 0. Matches only the exact header name."""
for raw_line in header_block.split(b"\r\n"):
name, sep, value = raw_line.partition(b":")
if not sep:
continue
if name.strip().lower() == b"content-length":
try:
return int(value.strip())
except ValueError:
return 0
return 0
def has_unsupported_transfer_encoding(header_block: bytes) -> bool:
"""True when the request uses Transfer-Encoding, which we don't stream."""
for raw_line in header_block.split(b"\r\n"):
name, sep, value = raw_line.partition(b":")
if not sep:
continue
if name.strip().lower() != b"transfer-encoding":
continue
encodings = [
token.strip().lower()
for token in value.decode(errors="replace").split(",")
if token.strip()
]
return any(token != "identity" for token in encodings)
return False
def load_host_rules(raw) -> tuple[set[str], tuple[str, ...]]:
"""Accept a list of host strings; return (exact_set, suffix_tuple)."""
exact: set[str] = set()
suffixes: list[str] = []
for item in raw or []:
host = str(item).strip().lower().rstrip(".")
if not host:
continue
if host.startswith("."):
suffixes.append(host)
else:
exact.add(host)
return exact, tuple(suffixes)
def host_matches_rules(host: str, rules: tuple[set[str], tuple[str, ...]]) -> bool:
exact, suffixes = rules
normalized = host.lower().rstrip(".")
if normalized in exact:
return True
return any(normalized.endswith(suffix) for suffix in suffixes)
def header_value(headers: dict | None, name: str) -> str:
if not headers:
return ""
for key, value in headers.items():
if key.lower() == name:
return str(value)
return ""
def should_trace_host(host: str, trace_suffixes: tuple[str, ...]) -> bool:
normalized = host.lower().rstrip(".")
return any(
token == normalized or token in normalized or normalized.endswith("." + token)
for token in trace_suffixes
)
def log_response_summary(
*,
logger: logging.Logger,
split_raw_response,
trace_suffixes: tuple[str, ...],
url: str,
response: bytes,
) -> None:
status, headers, body = split_raw_response(response)
host = (urlparse(url).hostname or "").lower()
if status < 300 and not should_trace_host(host, trace_suffixes):
return
location = headers.get("location", "") or "-"
server = headers.get("server", "") or "-"
cf_ray = headers.get("cf-ray", "") or "-"
content_type = headers.get("content-type", "") or "-"
body_len = len(body)
body_hint = "-"
rate_limited = False
if ("text" in content_type.lower() or "json" in content_type.lower()) and body:
sample = body[:1200].decode(errors="replace").lower()
if "<title>" in sample and "</title>" in sample:
title = sample.split("<title>", 1)[1].split("</title>", 1)[0]
body_hint = title.strip()[:120] or "-"
elif "captcha" in sample:
body_hint = "captcha"
elif "turnstile" in sample:
body_hint = "turnstile"
elif "loading" in sample:
body_hint = "loading"
rate_limit_markers = (
"too many",
"rate limit",
"quota",
"quota exceeded",
"request limit",
"دفعات زیاد",
"بیش از حد",
"سرویس در طول یک روز",
)
if any(marker in sample for marker in rate_limit_markers):
rate_limited = True
body_hint = "quota_exceeded"
log_msg = (
"RESP <- %s status=%s type=%s len=%s server=%s location=%s cf-ray=%s hint=%s"
)
log_args = (
host or url[:60],
status,
content_type,
body_len,
server,
location,
cf_ray,
body_hint,
)
if rate_limited:
logger.warning("RATE LIMIT detected! " + log_msg, *log_args)
else:
logger.info(log_msg, *log_args)
class ResponseCache:
"""Simple LRU response cache for relayable static responses."""
def __init__(self, max_mb: int = 50):
self._store: dict[str, tuple[bytes, float]] = {}
self._size = 0
self._max = max_mb * 1024 * 1024
self.hits = 0
self.misses = 0
def get(self, url: str) -> bytes | None:
entry = self._store.get(url)
if not entry:
self.misses += 1
return None
raw, expires = entry
if time.time() > expires:
self._size -= len(raw)
del self._store[url]
self.misses += 1
return None
self.hits += 1
return raw
def put(self, url: str, raw_response: bytes, ttl: int = 300):
size = len(raw_response)
if size > self._max // 4 or size == 0:
return
while self._size + size > self._max and self._store:
oldest = next(iter(self._store))
self._size -= len(self._store[oldest][0])
del self._store[oldest]
if url in self._store:
self._size -= len(self._store[url][0])
self._store[url] = (raw_response, time.time() + ttl)
self._size += size
@staticmethod
def parse_ttl(raw_response: bytes, url: str) -> int:
"""Determine cache TTL from response headers and URL."""
hdr_end = raw_response.find(b"\r\n\r\n")
if hdr_end < 0:
return 0
hdr = raw_response[:hdr_end].decode(errors="replace").lower()
if b"HTTP/1.1 200" not in raw_response[:20]:
return 0
if "no-store" in hdr or "private" in hdr or "set-cookie:" in hdr:
return 0
max_age_match = re.search(r"max-age=(\d+)", hdr)
if max_age_match:
return min(int(max_age_match.group(1)), CACHE_TTL_MAX)
path = url.split("?")[0].lower()
for ext in STATIC_EXTS:
if path.endswith(ext):
return CACHE_TTL_STATIC_LONG
content_type_match = re.search(r"content-type:\s*([^\r\n]+)", hdr)
content_type = content_type_match.group(1) if content_type_match else ""
if "image/" in content_type or "font/" in content_type:
return CACHE_TTL_STATIC_LONG
if "text/css" in content_type or "javascript" in content_type:
return CACHE_TTL_STATIC_MED
if "text/html" in content_type or "application/json" in content_type:
return 0
return 0
# ── CORS helpers ──────────────────────────────────────────────────────────────
def cors_preflight_response(origin: str, acr_method: str, acr_headers: str) -> bytes:
"""Build a 204 response that satisfies a CORS preflight locally.
Apps Script's UrlFetchApp does not support OPTIONS, so preflights must
be answered here rather than forwarded to the relay.
"""
allow_origin = origin or "*"
allow_methods = (
f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS"
if acr_method else
"GET, POST, PUT, DELETE, PATCH, OPTIONS"
)
allow_headers = acr_headers or "*"
return (
"HTTP/1.1 204 No Content\r\n"
f"Access-Control-Allow-Origin: {allow_origin}\r\n"
f"Access-Control-Allow-Methods: {allow_methods}\r\n"
f"Access-Control-Allow-Headers: {allow_headers}\r\n"
"Access-Control-Allow-Credentials: true\r\n"
"Access-Control-Max-Age: 86400\r\n"
"Vary: Origin\r\n"
"Content-Length: 0\r\n"
"\r\n"
).encode()
def inject_cors_headers(response: bytes, origin: str) -> bytes:
"""Strip existing Access-Control-* headers and inject permissive ones.
Keeps the body untouched; only rewrites the header block. Using the
exact browser-supplied Origin (rather than "*") is required when the
request is credentialed (cookies, Authorization).
"""
sep = b"\r\n\r\n"
if sep not in response:
return response
header_section, body = response.split(sep, 1)
lines = header_section.decode(errors="replace").split("\r\n")
lines = [ln for ln in lines if not ln.lower().startswith("access-control-")]
allow_origin = origin or "*"
lines += [
f"Access-Control-Allow-Origin: {allow_origin}",
"Access-Control-Allow-Credentials: true",
"Access-Control-Allow-Methods: GET, POST, PUT, DELETE, PATCH, OPTIONS",
"Access-Control-Allow-Headers: *",
"Access-Control-Expose-Headers: *",
"Vary: Origin",
]
return ("\r\n".join(lines) + "\r\n\r\n").encode() + body
+88
View File
@@ -0,0 +1,88 @@
"""
SOCKS5 protocol negotiation helpers.
Implements the RFC 1928 handshake for CONNECT (TCP BIND) requests.
Only no-authentication (method 0x00) and CONNECT (cmd 0x01) are supported,
which covers all standard proxy use cases (HTTPS, HTTP, arbitrary TCP).
Usage::
host, port = await negotiate_socks5(reader, writer)
# host/port are None if negotiation failed (caller should close)
"""
from __future__ import annotations
import asyncio
import socket
__all__ = ["negotiate_socks5"]
async def negotiate_socks5(
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> tuple[str, int] | None:
"""Perform a SOCKS5 handshake and return the requested (host, port).
Sends protocol-level replies directly to *writer*. Returns ``None``
and leaves the connection in a closed state if negotiation fails at
any step (unsupported version, method, command, or address type).
Raises:
asyncio.IncompleteReadError: if the client closes the connection
mid-handshake.
asyncio.TimeoutError: propagated from the individual ``wait_for``
calls so the caller can log it separately.
"""
# ── Auth negotiation ──────────────────────────────────────────
header = await asyncio.wait_for(reader.readexactly(2), timeout=15)
ver, nmethods = header[0], header[1]
if ver != 5:
return None
methods = await asyncio.wait_for(reader.readexactly(nmethods), timeout=10)
if 0x00 not in methods:
# No acceptable method — reject
writer.write(b"\x05\xff")
await writer.drain()
return None
# Accept: no authentication required
writer.write(b"\x05\x00")
await writer.drain()
# ── Request ───────────────────────────────────────────────────
req = await asyncio.wait_for(reader.readexactly(4), timeout=15)
ver, cmd, _rsv, atyp = req
if ver != 5 or cmd != 0x01:
# Only CONNECT (0x01) is supported
writer.write(b"\x05\x07\x00\x01\x00\x00\x00\x00\x00\x00")
await writer.drain()
return None
# ── Address parsing ───────────────────────────────────────────
if atyp == 0x01: # IPv4
raw = await asyncio.wait_for(reader.readexactly(4), timeout=10)
host = socket.inet_ntoa(raw)
elif atyp == 0x03: # Domain name
ln = (await asyncio.wait_for(reader.readexactly(1), timeout=10))[0]
host = (
await asyncio.wait_for(reader.readexactly(ln), timeout=10)
).decode(errors="replace")
elif atyp == 0x04: # IPv6
raw = await asyncio.wait_for(reader.readexactly(16), timeout=10)
host = socket.inet_ntop(socket.AF_INET6, raw)
else:
writer.write(b"\x05\x08\x00\x01\x00\x00\x00\x00\x00\x00")
await writer.drain()
return None
port_raw = await asyncio.wait_for(reader.readexactly(2), timeout=10)
port = int.from_bytes(port_raw, "big")
# ── Success reply ─────────────────────────────────────────────
writer.write(b"\x05\x00\x00\x01\x00\x00\x00\x00\x00\x00")
await writer.drain()
return host, port
+6
View File
@@ -0,0 +1,6 @@
from .domain_fronter import DomainFronter
from .relay_response import *
from .fronting_support import *
from .h2_transport import *
from .http_reader import *
__all__ = ["DomainFronter"]
@@ -20,7 +20,6 @@ import ssl
import statistics import statistics
import tempfile import tempfile
import time import time
from dataclasses import dataclass
from urllib.parse import urlparse from urllib.parse import urlparse
try: try:
@@ -28,13 +27,12 @@ try:
except Exception: # optional dependency fallback except Exception: # optional dependency fallback
certifi = None certifi = None
import codec from core import codec
from constants import ( from core.constants import (
BATCH_MAX, BATCH_MAX,
BATCH_WINDOW_MACRO, BATCH_WINDOW_MACRO,
BATCH_WINDOW_MICRO, BATCH_WINDOW_MICRO,
CONN_TTL, CONN_TTL,
FRONT_SNI_POOL_GOOGLE,
MAX_RESPONSE_BODY_BYTES, MAX_RESPONSE_BODY_BYTES,
POOL_MAX, POOL_MAX,
POOL_MIN_IDLE, POOL_MIN_IDLE,
@@ -48,52 +46,33 @@ from constants import (
TLS_CONNECT_TIMEOUT, TLS_CONNECT_TIMEOUT,
WARM_POOL_COUNT, WARM_POOL_COUNT,
) )
from .fronting_support import (
HostStat,
build_sni_pool,
format_bytes_human,
format_elapsed_short,
parse_content_range,
progress_line,
render_progress_bar,
spool_read,
spool_write,
validate_range_response,
)
from .relay_response import (
classify_relay_error,
error_response,
extract_apps_script_user_html,
load_relay_json,
parse_relay_json,
parse_relay_response,
split_raw_response,
split_set_cookie,
)
from .http_reader import read_http_response
log = logging.getLogger("Fronter") log = logging.getLogger("Fronter")
@dataclass
class HostStat:
"""Per-host traffic accounting — useful for profiling slow / heavy sites."""
requests: int = 0
cache_hits: int = 0
bytes: int = 0
total_latency_ns: int = 0
errors: int = 0
def _build_sni_pool(front_domain: str, overrides: list | None) -> list[str]:
"""Build the list of SNIs to rotate through on new outbound TLS handshakes.
Priority:
1. Explicit `front_domains` list in config (overrides).
2. If `front_domain` is a Google property, use FRONT_SNI_POOL_GOOGLE
(all share the same Google edge IP, so rotation is invisible to
the relay but breaks DPI's "always www.google.com" heuristic).
3. Fall back to the single configured `front_domain`.
"""
if overrides:
seen: set[str] = set()
out: list[str] = []
for item in overrides:
host = str(item).strip().lower().rstrip(".")
if host and host not in seen:
seen.add(host)
out.append(host)
if out:
return out
fd = (front_domain or "").lower().rstrip(".")
if fd.endswith(".google.com") or fd == "google.com":
# For Google fronting we prefer the curated pool order, which can be
# latency-biased for common censored networks. Include the configured
# front_domain if it is custom, but do not pin it first.
pool = list(FRONT_SNI_POOL_GOOGLE)
if fd and fd not in pool:
pool.insert(0, fd)
return pool
return [fd] if fd else ["www.google.com"]
class DomainFronter: class DomainFronter:
_STATIC_EXTS = STATIC_EXTS _STATIC_EXTS = STATIC_EXTS
_H2_FAILURE_COOLDOWN = 60.0 _H2_FAILURE_COOLDOWN = 60.0
@@ -114,7 +93,7 @@ class DomainFronter:
self.sni_host = config.get("front_domain", "www.google.com") self.sni_host = config.get("front_domain", "www.google.com")
# SNI rotation pool — rotated per new outbound TLS connection so # SNI rotation pool — rotated per new outbound TLS connection so
# DPI systems can't fingerprint traffic as "always one SNI". # DPI systems can't fingerprint traffic as "always one SNI".
self._sni_hosts = _build_sni_pool( self._sni_hosts = build_sni_pool(
self.sni_host, config.get("front_domains"), self.sni_host, config.get("front_domains"),
) )
self._sni_idx = 0 self._sni_idx = 0
@@ -171,6 +150,10 @@ class DomainFronter:
self._keepalive_task: asyncio.Task | None = None self._keepalive_task: asyncio.Task | None = None
self._warm_task: asyncio.Task | None = None self._warm_task: asyncio.Task | None = None
self._bg_tasks: set[asyncio.Task] = set() self._bg_tasks: set[asyncio.Task] = set()
# Set by _do_warm() when the initial TLS connection batch is open.
# The very first relay() call awaits this (with a short timeout) so it
# never dispatches a request onto a completely cold pool.
self._pool_ready = asyncio.Event()
# Batch collector for grouping concurrent relay() calls # Batch collector for grouping concurrent relay() calls
self._batch_lock = asyncio.Lock() self._batch_lock = asyncio.Lock()
@@ -192,7 +175,7 @@ class DomainFronter:
# HTTP/2 multiplexing — one connection handles all requests # HTTP/2 multiplexing — one connection handles all requests
self._h2 = None self._h2 = None
try: try:
from h2_transport import H2Transport, H2_AVAILABLE from .h2_transport import H2Transport, H2_AVAILABLE
if H2_AVAILABLE: if H2_AVAILABLE:
self._h2 = H2Transport( self._h2 = H2Transport(
self.connect_host, self.sni_host, self.verify_ssl, self.connect_host, self.sni_host, self.verify_ssl,
@@ -626,87 +609,6 @@ class DomainFronter:
lines.append("") lines.append("")
return "\r\n".join(lines).encode() return "\r\n".join(lines).encode()
@staticmethod
def _parse_content_range(value: str) -> tuple[int, int, int] | None:
match = re.match(r"^\s*bytes\s+(\d+)-(\d+)/(\d+)\s*$", value or "")
if not match:
return None
start, end, total = (int(group) for group in match.groups())
if start < 0 or end < start or total <= end:
return None
return start, end, total
@classmethod
def _validate_range_response(cls, status: int, resp_headers: dict,
body: bytes, start_off: int,
end_off: int,
total_size: int | None = None) -> str | None:
if status != 206:
return f"status {status}"
parsed = cls._parse_content_range(resp_headers.get("content-range", ""))
if not parsed:
return "missing/invalid Content-Range"
got_start, got_end, got_total = parsed
if got_start != start_off or got_end != end_off:
return f"Content-Range mismatch {got_start}-{got_end}"
if total_size is not None and got_total != total_size:
return f"Content-Range total mismatch {got_total}/{total_size}"
expected = end_off - start_off + 1
if len(body) != expected:
return f"short chunk {len(body)}/{expected} B"
return None
@staticmethod
def _spool_write(file_obj, offset: int, data: bytes) -> None:
file_obj.seek(offset)
file_obj.write(data)
file_obj.flush()
@staticmethod
def _spool_read(file_obj, offset: int, size: int) -> bytes:
file_obj.seek(offset)
return file_obj.read(size)
@staticmethod
def _format_bytes_human(num_bytes: int) -> str:
value = float(max(0, num_bytes))
units = ("B", "KiB", "MiB", "GiB", "TiB")
unit = units[0]
for unit in units:
if value < 1024.0 or unit == units[-1]:
break
value /= 1024.0
if unit == "B":
return f"{int(value)} {unit}"
return f"{value:.1f} {unit}"
@staticmethod
def _format_elapsed_short(seconds: float) -> str:
total = max(0, int(seconds))
minutes, secs = divmod(total, 60)
hours, minutes = divmod(minutes, 60)
if hours:
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
return f"{minutes:02d}:{secs:02d}"
@staticmethod
def _render_progress_bar(done: int, total: int, width: int = 34) -> str:
if total <= 0:
return "[" + ("-" * width) + "]"
ratio = max(0.0, min(1.0, done / total))
filled = min(width, int(round(ratio * width)))
return "[" + ("#" * filled) + ("-" * (width - filled)) + "]"
@classmethod
def _progress_line(cls, *, elapsed: float, done: int, total: int,
speed_bytes_per_sec: float) -> str:
return (
f"[{cls._format_elapsed_short(elapsed)}] "
f"{cls._render_progress_bar(done, total)} "
f"{cls._format_bytes_human(done)} / {cls._format_bytes_human(total)} "
f"({cls._format_bytes_human(int(speed_bytes_per_sec))}/s)"
)
async def _relay_payload_h1(self, payload: dict) -> bytes: async def _relay_payload_h1(self, payload: dict) -> bytes:
attempts = self._retry_attempts_for_payload(payload) attempts = self._retry_attempts_for_payload(payload)
async with self._semaphore: async with self._semaphore:
@@ -745,7 +647,7 @@ class DomainFronter:
await asyncio.sleep(0.3 * (attempt + 1)) await asyncio.sleep(0.3 * (attempt + 1))
continue continue
last_status, _, _ = self._split_raw_response(last_raw) last_status, _, _ = split_raw_response(last_raw)
if last_status == 206 or last_status < 500: if last_status == 206 or last_status < 500:
return last_raw return last_raw
if attempt < max_tries - 1: if attempt < max_tries - 1:
@@ -1020,7 +922,7 @@ class DomainFronter:
timeout=15, timeout=15,
) )
dt = (time.perf_counter() - t0) * 1000 dt = (time.perf_counter() - t0) * 1000
data = self._load_relay_json(body.decode(errors="replace")) data = load_relay_json(body.decode(errors="replace"))
if "s" in data: if "s" in data:
self._dev_available = True self._dev_available = True
log.info("/dev fast path active (%.0fms, no redirect)", dt) log.info("/dev fast path active (%.0fms, no redirect)", dt)
@@ -1116,6 +1018,9 @@ class DomainFronter:
results = await asyncio.gather(*coros, return_exceptions=True) results = await asyncio.gather(*coros, return_exceptions=True)
opened = sum(1 for r in results if not isinstance(r, Exception)) opened = sum(1 for r in results if not isinstance(r, Exception))
log.info("Pre-warmed %d/%d TLS connections", opened, count) log.info("Pre-warmed %d/%d TLS connections", opened, count)
# Signal that at least the pool-open phase finished so relay() can
# stop waiting on the first request.
self._pool_ready.set()
def _auth_header(self) -> str: def _auth_header(self) -> str:
return f"X-Auth-Key: {self.auth_key}\r\n" if self.auth_key else "" return f"X-Auth-Key: {self.auth_key}\r\n" if self.auth_key else ""
@@ -1238,8 +1143,8 @@ class DomainFronter:
# _parse_relay_response will decode it into the final HTTP response. # _parse_relay_response will decode it into the final HTTP response.
# But we need to unwrap one level: Apps Script gives us exit node HTTP # But we need to unwrap one level: Apps Script gives us exit node HTTP
# response body (which is itself a relay JSON), so parse twice. # response body (which is itself a relay JSON), so parse twice.
_, _, apps_script_body = self._split_raw_response(raw) _, _, apps_script_body = split_raw_response(raw)
result = self._parse_relay_response(apps_script_body) result = parse_relay_response(apps_script_body, self._max_response_body_bytes)
log.debug("Exit node relay OK: %s", payload.get("u", "")[:80]) log.debug("Exit node relay OK: %s", payload.get("u", "")[:80])
return result return result
@@ -1261,6 +1166,20 @@ class DomainFronter:
if not self._warmed: if not self._warmed:
await self._warm_pool() await self._warm_pool()
# On the very first request, wait up to one TLS-connect-timeout for the
# pool to have at least one open connection. This prevents the first
# browser request from racing onto a completely cold pool. The wait is
# capped so a slow network never blocks the user indefinitely — the
# normal retry/fallback path handles it from there.
if not self._pool_ready.is_set():
try:
await asyncio.wait_for(
asyncio.shield(self._pool_ready.wait()),
timeout=self._tls_connect_timeout,
)
except asyncio.TimeoutError:
log.debug("Pool warm timeout — proceeding with cold pool")
payload = self._build_payload(method, url, headers, body) payload = self._build_payload(method, url, headers, body)
# Exit node short-circuit: route to non-Google IP before Apps Script # Exit node short-circuit: route to non-Google IP before Apps Script
@@ -1383,7 +1302,7 @@ class DomainFronter:
# Probe: first chunk with Range header # Probe: first chunk with Range header
first_resp = await self._range_probe(url, headers, 0, chunk_size - 1) first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
status, resp_hdrs, resp_body = self._split_raw_response(first_resp) status, resp_hdrs, resp_body = split_raw_response(first_resp)
# No range support → return the single response as-is (status 200 # No range support → return the single response as-is (status 200
# from the origin). The client sent a plain GET, so 200 is what it # from the origin). The client sent a plain GET, so 200 is what it
@@ -1392,19 +1311,19 @@ class DomainFronter:
return first_resp return first_resp
# Parse total size from Content-Range: "bytes 0-262143/1048576" # Parse total size from Content-Range: "bytes 0-262143/1048576"
parsed_range = self._parse_content_range(resp_hdrs.get("content-range", "")) parsed_range = parse_content_range(resp_hdrs.get("content-range", ""))
if not parsed_range: if not parsed_range:
# Can't parse — downgrade to 200 so the client (which sent a # Can't parse — downgrade to 200 so the client (which sent a
# plain GET) doesn't get confused by 206 + Content-Range. # plain GET) doesn't get confused by 206 + Content-Range.
return self._rewrite_206_to_200(first_resp) return self._rewrite_206_to_200(first_resp)
first_start, first_end, total_size = parsed_range first_start, first_end, total_size = parsed_range
first_err = self._validate_range_response( first_err = validate_range_response(
status, resp_hdrs, resp_body, first_start, first_end, total_size, status, resp_hdrs, resp_body, first_start, first_end, total_size,
) )
if first_start != 0 or first_err: if first_start != 0 or first_err:
return self._rewrite_206_to_200(first_resp) return self._rewrite_206_to_200(first_resp)
if total_size > self._max_response_body_bytes: if total_size > self._max_response_body_bytes:
return self._error_response( return error_response(
502, 502,
"Relay response exceeds cap " "Relay response exceeds cap "
f"({self._max_response_body_bytes} bytes). " f"({self._max_response_body_bytes} bytes). "
@@ -1464,8 +1383,8 @@ class DomainFronter:
for attempt in range(max_tries): for attempt in range(max_tries):
try: try:
raw = await self._relay_payload_h1(payload) raw = await self._relay_payload_h1(payload)
chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw) chunk_status, chunk_headers, chunk_body = split_raw_response(raw)
err = self._validate_range_response( err = validate_range_response(
chunk_status, chunk_headers, chunk_body, chunk_status, chunk_headers, chunk_body,
s, e, total_size, s, e, total_size,
) )
@@ -1483,7 +1402,7 @@ class DomainFronter:
speed_bps = completed_bytes / elapsed speed_bps = completed_bytes / elapsed
log.info( log.info(
"Parallel download progress: %s [%d/%d chunks]", "Parallel download progress: %s [%d/%d chunks]",
self._progress_line( progress_line(
elapsed=elapsed, elapsed=elapsed,
done=completed_bytes, done=completed_bytes,
total=total_bytes, total=total_bytes,
@@ -1515,14 +1434,14 @@ class DomainFronter:
for i, r in enumerate(results): for i, r in enumerate(results):
if isinstance(r, Exception): if isinstance(r, Exception):
log.error("Range chunk %d failed: %s", i, r) log.error("Range chunk %d failed: %s", i, r)
return self._error_response(502, f"Parallel download failed: {r}") return error_response(502, f"Parallel download failed: {r}")
parts.append(r) parts.append(r)
full_body = b"".join(parts) full_body = b"".join(parts)
kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0 kbs = (len(full_body) / 1024) / elapsed if elapsed > 0 else 0
log.info( log.info(
"Parallel download complete: %s", "Parallel download complete: %s",
self._progress_line( progress_line(
elapsed=elapsed, elapsed=elapsed,
done=len(full_body), done=len(full_body),
total=len(full_body), total=len(full_body),
@@ -1557,7 +1476,7 @@ class DomainFronter:
""" """
first_resp = await self._range_probe(url, headers, 0, chunk_size - 1) first_resp = await self._range_probe(url, headers, 0, chunk_size - 1)
status, resp_hdrs, resp_body = self._split_raw_response(first_resp) status, resp_hdrs, resp_body = split_raw_response(first_resp)
if status != 206: if status != 206:
log.info( log.info(
"Streaming download fallback: initial probe returned %s for %s", "Streaming download fallback: initial probe returned %s for %s",
@@ -1565,7 +1484,7 @@ class DomainFronter:
) )
return False return False
parsed_range = self._parse_content_range(resp_hdrs.get("content-range", "")) parsed_range = parse_content_range(resp_hdrs.get("content-range", ""))
if not parsed_range: if not parsed_range:
log.info( log.info(
"Streaming download fallback: missing/invalid Content-Range for %s", "Streaming download fallback: missing/invalid Content-Range for %s",
@@ -1573,7 +1492,7 @@ class DomainFronter:
) )
return False return False
first_start, first_end, total_size = parsed_range first_start, first_end, total_size = parsed_range
first_err = self._validate_range_response( first_err = validate_range_response(
status, resp_hdrs, resp_body, first_start, first_end, total_size, status, resp_hdrs, resp_body, first_start, first_end, total_size,
) )
if first_start != 0 or first_err: if first_start != 0 or first_err:
@@ -1642,7 +1561,7 @@ class DomainFronter:
speed_bps = delivered_bytes / elapsed speed_bps = delivered_bytes / elapsed
log.info( log.info(
"Parallel download progress: %s [%d/%d chunks]", "Parallel download progress: %s [%d/%d chunks]",
self._progress_line( progress_line(
elapsed=elapsed, elapsed=elapsed,
done=delivered_bytes, done=delivered_bytes,
total=total_size, total=total_size,
@@ -1666,15 +1585,15 @@ class DomainFronter:
return return
try: try:
raw = await self._relay_payload_h1(payload) raw = await self._relay_payload_h1(payload)
chunk_status, chunk_headers, chunk_body = self._split_raw_response(raw) chunk_status, chunk_headers, chunk_body = split_raw_response(raw)
err = self._validate_range_response( err = validate_range_response(
chunk_status, chunk_headers, chunk_body, chunk_status, chunk_headers, chunk_body,
start_off, end_off, total_size, start_off, end_off, total_size,
) )
if err is None: if err is None:
async with file_lock: async with file_lock:
await asyncio.to_thread( await asyncio.to_thread(
self._spool_write, temp_file, start_off, chunk_body, spool_write, temp_file, start_off, chunk_body,
) )
ready[index].set() ready[index].set()
return return
@@ -1708,7 +1627,7 @@ class DomainFronter:
expected = end_off - start_off + 1 expected = end_off - start_off + 1
async with file_lock: async with file_lock:
chunk = await asyncio.to_thread( chunk = await asyncio.to_thread(
self._spool_read, temp_file, start_off, expected, spool_read, temp_file, start_off, expected,
) )
if len(chunk) != expected: if len(chunk) != expected:
raise RuntimeError( raise RuntimeError(
@@ -1724,7 +1643,7 @@ class DomainFronter:
elapsed = max(0.001, time.perf_counter() - t0) elapsed = max(0.001, time.perf_counter() - t0)
log.info( log.info(
"Parallel streaming download complete: %s", "Parallel streaming download complete: %s",
self._progress_line( progress_line(
elapsed=elapsed, elapsed=elapsed,
done=total_size, done=total_size,
total=total_size, total=total_size,
@@ -1935,7 +1854,7 @@ class DomainFronter:
future.set_result(result) future.set_result(result)
except Exception as e: except Exception as e:
if not future.done(): if not future.done():
future.set_result(self._error_response(502, str(e))) future.set_result(error_response(502, str(e)))
else: else:
log.info("Batch relay: %d requests", len(batch)) log.info("Batch relay: %d requests", len(batch))
try: try:
@@ -1965,7 +1884,7 @@ class DomainFronter:
future.set_result(result) future.set_result(result)
except Exception as e: except Exception as e:
if not future.done(): if not future.done():
future.set_result(self._error_response(502, str(e))) future.set_result(error_response(502, str(e)))
# ── Core relay with retry ───────────────────────────────────── # ── Core relay with retry ─────────────────────────────────────
@@ -2101,7 +2020,7 @@ class DomainFronter:
body=json_body, body=json_body,
) )
return self._parse_relay_response(body) return parse_relay_response(body, self._max_response_body_bytes)
async def _relay_single_h2_with_sid(self, payload: dict, async def _relay_single_h2_with_sid(self, payload: dict,
sid: str) -> bytes: sid: str) -> bytes:
@@ -2122,7 +2041,7 @@ class DomainFronter:
body=json_body, body=json_body,
) )
return self._parse_relay_response(body) return parse_relay_response(body, self._max_response_body_bytes)
async def _relay_single(self, payload: dict) -> bytes: async def _relay_single(self, payload: dict) -> bytes:
"""Execute a single relay POST → redirect → parse.""" """Execute a single relay POST → redirect → parse."""
@@ -2147,7 +2066,7 @@ class DomainFronter:
writer.write(request.encode() + json_body) writer.write(request.encode() + json_body)
await writer.drain() await writer.drain()
status, resp_headers, resp_body = await self._read_http_response(reader) status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes)
# Follow redirect chain on the SAME connection # Follow redirect chain on the SAME connection
for _ in range(5): for _ in range(5):
@@ -2176,10 +2095,10 @@ class DomainFronter:
request = "\r\n".join(request_lines) + "\r\n\r\n" request = "\r\n".join(request_lines) + "\r\n\r\n"
writer.write(request.encode() + redirect_body) writer.write(request.encode() + redirect_body)
await writer.drain() await writer.drain()
status, resp_headers, resp_body = await self._read_http_response(reader) status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes)
await self._release(reader, writer, created) await self._release(reader, writer, created)
return self._parse_relay_response(resp_body) return parse_relay_response(resp_body, self._max_response_body_bytes)
except Exception: except Exception:
try: try:
@@ -2230,7 +2149,7 @@ class DomainFronter:
writer.write(request.encode() + json_body) writer.write(request.encode() + json_body)
await writer.drain() await writer.drain()
status, resp_headers, resp_body = await self._read_http_response(reader) status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes)
# Follow redirects # Follow redirects
for _ in range(5): for _ in range(5):
@@ -2258,7 +2177,7 @@ class DomainFronter:
request = "\r\n".join(request_lines) + "\r\n\r\n" request = "\r\n".join(request_lines) + "\r\n\r\n"
writer.write(request.encode() + redirect_body) writer.write(request.encode() + redirect_body)
await writer.drain() await writer.drain()
status, resp_headers, resp_body = await self._read_http_response(reader) status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes)
await self._release(reader, writer, created) await self._release(reader, writer, created)
@@ -2297,394 +2216,6 @@ class DomainFronter:
results = [] results = []
for item in items: for item in items:
results.append(self._parse_relay_json(item)) results.append(parse_relay_json(item, self._max_response_body_bytes))
return results return results
# ── HTTP response reading (keep-alive safe) ──────────────────
async def _read_http_response(self, reader: asyncio.StreamReader):
"""Read one HTTP response. Keep-alive safe (no read-until-EOF)."""
raw = b""
while b"\r\n\r\n" not in raw:
if len(raw) > 65536: # 64 KB header size limit
return 0, {}, b""
chunk = await asyncio.wait_for(reader.read(8192), timeout=8)
if not chunk:
break
raw += chunk
if b"\r\n\r\n" not in raw:
return 0, {}, b""
header_section, body = raw.split(b"\r\n\r\n", 1)
lines = header_section.split(b"\r\n")
status_line = lines[0].decode(errors="replace")
m = re.search(r"\d{3}", status_line)
status = int(m.group()) if m else 0
headers = {}
for line in lines[1:]:
if b":" in line:
k, v = line.decode(errors="replace").split(":", 1)
headers[k.strip().lower()] = v.strip()
content_length = headers.get("content-length")
transfer_encoding = headers.get("transfer-encoding", "")
if "chunked" in transfer_encoding:
body = await self._read_chunked(reader, body)
elif content_length:
total = int(content_length)
if total > self._max_response_body_bytes:
raise RuntimeError(
"Relay response exceeds configured size cap "
f"({total} > {self._max_response_body_bytes} bytes)"
)
remaining = total - len(body)
while remaining > 0:
chunk = await asyncio.wait_for(
reader.read(min(remaining, 65536)), timeout=20
)
if not chunk:
break
body += chunk
if len(body) > self._max_response_body_bytes:
raise RuntimeError(
"Relay response exceeded configured size cap while reading body"
)
remaining -= len(chunk)
else:
# No framing — short timeout read (keep-alive safe)
while True:
try:
chunk = await asyncio.wait_for(reader.read(65536), timeout=2)
if not chunk:
break
body += chunk
if len(body) > self._max_response_body_bytes:
raise RuntimeError(
"Relay response exceeded configured size cap while streaming"
)
except asyncio.TimeoutError:
break
# Auto-decompress (gzip/deflate/br/zstd) from Google frontend
enc = headers.get("content-encoding", "")
if enc:
body = codec.decode(body, enc)
if len(body) > self._max_response_body_bytes:
raise RuntimeError(
"Decoded relay response exceeded configured size cap"
)
return status, headers, body
async def _read_chunked(self, reader, buf=b""):
"""Incrementally read chunked transfer-encoding."""
result = b""
max_body = self._max_response_body_bytes
while True:
while b"\r\n" not in buf:
data = await asyncio.wait_for(reader.read(8192), timeout=20)
if not data:
return result
buf += data
end = buf.find(b"\r\n")
size_str = buf[:end].decode(errors="replace").strip()
buf = buf[end + 2:]
if not size_str:
continue
try:
size = int(size_str, 16)
except ValueError:
break
if size == 0:
break
if size > max_body or len(result) + size > max_body:
raise RuntimeError(
"Chunked relay response exceeded configured size cap "
f"({max_body} bytes)"
)
while len(buf) < size + 2:
data = await asyncio.wait_for(reader.read(65536), timeout=20)
if not data:
result += buf[:size]
return result
buf += data
result += buf[:size]
buf = buf[size + 2:]
return result
# ── Response parsing ──────────────────────────────────────────
def _parse_relay_response(self, body: bytes) -> bytes:
"""Parse JSON from Apps Script and reconstruct an HTTP response."""
text = body.decode(errors="replace").strip()
if not text:
return self._error_response(502, "Empty response from relay")
data = self._load_relay_json(text)
if data is None:
return self._error_response(502, f"No JSON: {text[:200]}")
return self._parse_relay_json(data)
@staticmethod
def _load_relay_json(text: str) -> dict | None:
try:
return json.loads(text)
except json.JSONDecodeError:
wrapped = DomainFronter._extract_apps_script_user_html(text)
if wrapped:
data = DomainFronter._load_relay_json(wrapped)
if data is not None:
return data
match = re.search(r'\{.*\}', text, re.DOTALL)
if not match:
return None
try:
data = json.loads(match.group())
except json.JSONDecodeError:
return None
return data if isinstance(data, dict) else None
@staticmethod
def _extract_apps_script_user_html(text: str) -> str | None:
marker = 'goog.script.init("'
start = text.find(marker)
if start == -1:
return None
start += len(marker)
end = text.find('", "", undefined', start)
if end == -1:
return None
encoded = text[start:end]
try:
decoded = codecs.decode(encoded, "unicode_escape")
payload = json.loads(decoded)
except Exception:
return None
user_html = payload.get("userHtml")
return user_html if isinstance(user_html, str) else None
# ── Apps Script error classifier ─────────────────────────────
# Patterns are matched against the lower-cased raw error string from
# Apps Script's `e` field. Sources:
# • https://developers.google.com/apps-script/guides/support/troubleshooting
# • https://developers.google.com/apps-script/guides/services/quotas
# • Google Issue Tracker (urlfetch / bandwidth quota issues)
# "Service invoked too many times for one day: urlfetch."
# "Bandwidth quota exceeded"
# "UrlFetch failed because too much upload bandwidth was used"
# "UrlFetch failed because too much traffic is being sent to the specified URL"
_QUOTA_PATTERNS = (
"service invoked too many times",
"invoked too many times",
"bandwidth quota exceeded",
"too much upload bandwidth",
"too much traffic",
"urlfetch", # appears at end of the daily-quota message in all locales
"quota",
"exceeded",
"daily",
"rate limit",
)
# "Authorization is required to perform that action."
# "unauthorized" (our own Code.gs response)
# "Access denied"
# "Permission denied"
_AUTH_PATTERNS = (
"authorization is required",
"unauthorized",
"not authorized",
"permission denied",
"access denied",
)
# "Error occurred due to a missing library version or a deployment version.
# Error code Not_Found"
# "script id not found" / wrong Deployment ID
_DEPLOY_PATTERNS = (
"error code not_found",
"not_found",
"deployment",
"script id",
"scriptid",
"no script",
)
# "Server not available." / "Server error occurred, please try again."
_TRANSIENT_PATTERNS = (
"server not available",
"server error occurred",
"please try again",
"temporarily unavailable",
)
# "UrlFetch calls to <URL> are not permitted by your admin"
# "<Class> / Apiary.<Service> is disabled. Please contact your administrator"
_ADMIN_PATTERNS = (
"not permitted by your admin",
"contact your administrator",
"disabled. please contact",
"domain policy has disabled",
"administrator to enable",
)
@classmethod
def _classify_relay_error(cls, raw: str) -> str:
"""Return a human-readable explanation for a known Apps Script error.
Covers every error category documented at:
developers.google.com/apps-script/guides/support/troubleshooting
"""
lower = raw.lower()
if any(p in lower for p in cls._QUOTA_PATTERNS):
return (
"Apps Script quota exhausted. "
"Either the 20,000 URL-fetch calls/day limit or the 100 MB/day "
"bandwidth limit has been reached. "
"Wait up to 24 hours for the quota to reset, or create a second "
"Google account, deploy a fresh Apps Script there, and add its "
"script_id to config.json."
)
if any(p in lower for p in cls._AUTH_PATTERNS):
return (
"Apps Script rejected the request (auth/permission error). "
"Check: (1) AUTH_KEY in Code.gs matches 'auth_key' in config.json, "
"(2) the deployment is set to 'Execute as: Me / Anyone can access', "
"(3) you are using the Deployment ID (not the Script ID), "
"(4) the owning Google account has authorised the script by running "
"it manually at least once."
)
if any(p in lower for p in cls._DEPLOY_PATTERNS):
return (
"Apps Script deployment not found. "
"Verify 'script_id' in config.json is the Deployment ID "
"(not the Script ID), the deployment is active/not archived, "
"and you re-created the deployment after editing Code.gs."
)
if any(p in lower for p in cls._TRANSIENT_PATTERNS):
return (
"Google Apps Script server is temporarily unavailable. "
"This is a transient Google-side error — wait a moment and retry. "
f"(raw: {raw})"
)
if any(p in lower for p in cls._ADMIN_PATTERNS):
return (
"Apps Script is blocked by a Google Workspace admin policy. "
"Either the target URL is not on the admin's UrlFetch allowlist, "
"or a Google service used by the script has been disabled by the "
"domain administrator. Contact your Google Workspace admin. "
f"(raw: {raw})"
)
# Unknown — strip the leading 'Exception: ' / 'Error: ' prefix that
# Apps Script always prepends, so the message is shorter and cleaner.
cleaned = re.sub(r'^(Exception|Error):\s*', '', raw, flags=re.IGNORECASE).strip()
return f"Relay error from Apps Script: {cleaned or raw}"
def _parse_relay_json(self, data: dict) -> bytes:
"""Convert a parsed relay JSON dict to raw HTTP response bytes."""
if "e" in data:
raw_err = str(data["e"])
friendly = self._classify_relay_error(raw_err)
log.warning("Apps Script error — %s | raw: %s", friendly.split(".")[0], raw_err)
return self._error_response(502, friendly)
status = data.get("s", 200)
resp_headers = data.get("h", {})
resp_body = base64.b64decode(data.get("b", ""))
if len(resp_body) > self._max_response_body_bytes:
return self._error_response(
502,
"Relay response exceeds cap "
f"({self._max_response_body_bytes} bytes). "
"Increase max_response_body_bytes if your system has enough RAM.",
)
status_text = {200: "OK", 206: "Partial Content",
301: "Moved", 302: "Found", 304: "Not Modified",
400: "Bad Request", 403: "Forbidden", 404: "Not Found",
500: "Internal Server Error"}.get(status, "OK")
result = f"HTTP/1.1 {status} {status_text}\r\n"
skip = {"transfer-encoding", "connection", "keep-alive",
"content-length", "content-encoding"}
for k, v in resp_headers.items():
if k.lower() in skip:
continue
# Apps Script returns multi-valued headers (e.g. Set-Cookie) as a
# JavaScript array. Emit each value as its own header line.
# A single string that holds multiple Set-Cookie values joined
# with ", " also needs to be split, otherwise the browser sees
# one malformed cookie and sites like x.com fail.
values = v if isinstance(v, list) else [v]
if k.lower() == "set-cookie":
expanded = []
for item in values:
expanded.extend(self._split_set_cookie(str(item)))
values = expanded
for val in values:
result += f"{k}: {val}\r\n"
result += f"Content-Length: {len(resp_body)}\r\n"
result += "\r\n"
return result.encode() + resp_body
@staticmethod
def _split_set_cookie(blob: str) -> list[str]:
"""Split a Set-Cookie string that may contain multiple cookies.
Apps Script sometimes joins multiple Set-Cookie values with ", ",
which collides with the comma that legitimately appears inside the
`Expires` attribute (e.g. "Expires=Wed, 21 Oct 2026 ..."). We split
only on commas that are immediately followed by a cookie name=value
pair (token '=' ...), leaving date commas intact.
"""
if not blob:
return []
# Split on ", " but only when the following text looks like the start
# of a new cookie (a token followed by '=').
parts = re.split(r",\s*(?=[A-Za-z0-9!#$%&'*+\-.^_`|~]+=)", blob)
return [p.strip() for p in parts if p.strip()]
def _split_raw_response(self, raw: bytes):
"""Split a raw HTTP response into (status, headers_dict, body)."""
if b"\r\n\r\n" not in raw:
return 0, {}, raw
header_section, body = raw.split(b"\r\n\r\n", 1)
lines = header_section.split(b"\r\n")
m = re.search(r"\d{3}", lines[0].decode(errors="replace"))
status = int(m.group()) if m else 0
headers = {}
for line in lines[1:]:
if b":" in line:
k, v = line.decode(errors="replace").split(":", 1)
headers[k.strip().lower()] = v.strip()
return status, headers, body
def _error_response(self, status: int, message: str) -> bytes:
body = f"<html><body><h1>{status}</h1><p>{message}</p></body></html>"
return (
f"HTTP/1.1 {status} Error\r\n"
f"Content-Type: text/html\r\n"
f"Content-Length: {len(body)}\r\n"
f"\r\n"
f"{body}"
).encode()
+146
View File
@@ -0,0 +1,146 @@
"""
Domain-fronting helper utilities: SNI pool building, range-request validation,
progress formatting, and stream spool read/write helpers.
Extracted from domain_fronter.py to separate pure helper logic from the
DomainFronter class.
"""
import re
from dataclasses import dataclass
from core.constants import FRONT_SNI_POOL_GOOGLE
__all__ = [
"HostStat",
"build_sni_pool",
"parse_content_range",
"validate_range_response",
"format_bytes_human",
"format_elapsed_short",
"render_progress_bar",
"progress_line",
"spool_write",
"spool_read",
]
@dataclass
class HostStat:
"""Per-host traffic accounting — useful for profiling slow / heavy sites."""
requests: int = 0
cache_hits: int = 0
bytes: int = 0
total_latency_ns: int = 0
errors: int = 0
def build_sni_pool(front_domain: str, overrides: list | None) -> list[str]:
"""Build the list of SNIs to rotate through on new outbound TLS handshakes."""
if overrides:
seen: set[str] = set()
out: list[str] = []
for item in overrides:
host = str(item).strip().lower().rstrip(".")
if host and host not in seen:
seen.add(host)
out.append(host)
if out:
return out
front_domain = (front_domain or "").lower().rstrip(".")
if front_domain.endswith(".google.com") or front_domain == "google.com":
pool = list(FRONT_SNI_POOL_GOOGLE)
if front_domain and front_domain not in pool:
pool.insert(0, front_domain)
return pool
return [front_domain] if front_domain else ["www.google.com"]
def parse_content_range(value: str) -> tuple[int, int, int] | None:
match = re.match(r"^\s*bytes\s+(\d+)-(\d+)/(\d+)\s*$", value or "")
if not match:
return None
start, end, total = (int(group) for group in match.groups())
if start < 0 or end < start or total <= end:
return None
return start, end, total
def validate_range_response(
status: int,
resp_headers: dict,
body: bytes,
start_off: int,
end_off: int,
total_size: int | None = None,
) -> str | None:
if status != 206:
return f"status {status}"
parsed = parse_content_range(resp_headers.get("content-range", ""))
if not parsed:
return "missing/invalid Content-Range"
got_start, got_end, got_total = parsed
if got_start != start_off or got_end != end_off:
return f"Content-Range mismatch {got_start}-{got_end}"
if total_size is not None and got_total != total_size:
return f"Content-Range total mismatch {got_total}/{total_size}"
expected = end_off - start_off + 1
if len(body) != expected:
return f"short chunk {len(body)}/{expected} B"
return None
def format_bytes_human(num_bytes: int) -> str:
value = float(max(0, num_bytes))
units = ("B", "KiB", "MiB", "GiB", "TiB")
unit = units[0]
for unit in units:
if value < 1024.0 or unit == units[-1]:
break
value /= 1024.0
if unit == "B":
return f"{int(value)} {unit}"
return f"{value:.1f} {unit}"
def format_elapsed_short(seconds: float) -> str:
total = max(0, int(seconds))
minutes, secs = divmod(total, 60)
hours, minutes = divmod(minutes, 60)
if hours:
return f"{hours:02d}:{minutes:02d}:{secs:02d}"
return f"{minutes:02d}:{secs:02d}"
def render_progress_bar(done: int, total: int, width: int = 34) -> str:
if total <= 0:
return "[" + ("-" * width) + "]"
ratio = max(0.0, min(1.0, done / total))
filled = min(width, int(round(ratio * width)))
return "[" + ("#" * filled) + ("-" * (width - filled)) + "]"
def progress_line(*, elapsed: float, done: int, total: int, speed_bytes_per_sec: float) -> str:
return (
f"[{format_elapsed_short(elapsed)}] "
f"{render_progress_bar(done, total)} "
f"{format_bytes_human(done)} / {format_bytes_human(total)} "
f"({format_bytes_human(int(speed_bytes_per_sec))}/s)"
)
# ── Parallel-range spool helpers ─────────────────────────────────────────────
def spool_write(file_obj, offset: int, data: bytes) -> None:
"""Write *data* at *offset* in a temp file used for parallel-range spooling."""
file_obj.seek(offset)
file_obj.write(data)
file_obj.flush()
def spool_read(file_obj, offset: int, size: int) -> bytes:
"""Read *size* bytes from *offset* in a parallel-range spool file."""
file_obj.seek(offset)
return file_obj.read(size)
@@ -25,7 +25,7 @@ try:
except Exception: # optional dependency fallback except Exception: # optional dependency fallback
certifi = None certifi = None
import codec from core import codec
log = logging.getLogger("H2") log = logging.getLogger("H2")
+163
View File
@@ -0,0 +1,163 @@
"""
HTTP/1.1 response reader for keep-alive connections.
Reads exactly one HTTP response from an asyncio StreamReader, handling
chunked transfer-encoding, Content-Length framing, and streaming bodies.
Auto-decompresses the response body according to the Content-Encoding
header (gzip, deflate, brotli, zstd).
Usage::
status, headers, body = await read_http_response(reader, max_bytes=50_000_000)
"""
from __future__ import annotations
import asyncio
import re
from core import codec
__all__ = ["read_http_response"]
async def read_http_response(
reader: asyncio.StreamReader,
*,
max_bytes: int,
) -> tuple[int, dict[str, str], bytes]:
"""Read one HTTP/1.1 response. Keep-alive safe (no read-until-EOF).
Args:
reader: An ``asyncio.StreamReader`` positioned at the start of
an HTTP response.
max_bytes: Hard cap on the decompressed body size. Raises
``RuntimeError`` if exceeded.
Returns:
A ``(status_code, headers, body)`` triple. ``status_code`` is 0
and the other fields are empty/empty if the response is malformed.
"""
# ── Read until header boundary ────────────────────────────────
raw = b""
while b"\r\n\r\n" not in raw:
if len(raw) > 65536: # 64 KB header size limit
return 0, {}, b""
chunk = await asyncio.wait_for(reader.read(8192), timeout=8)
if not chunk:
break
raw += chunk
if b"\r\n\r\n" not in raw:
return 0, {}, b""
header_section, body = raw.split(b"\r\n\r\n", 1)
lines = header_section.split(b"\r\n")
status_line = lines[0].decode(errors="replace")
m = re.search(r"\d{3}", status_line)
status = int(m.group()) if m else 0
headers: dict[str, str] = {}
for line in lines[1:]:
if b":" in line:
k, v = line.decode(errors="replace").split(":", 1)
headers[k.strip().lower()] = v.strip()
# ── Body framing ──────────────────────────────────────────────
content_length = headers.get("content-length")
transfer_encoding = headers.get("transfer-encoding", "")
if "chunked" in transfer_encoding:
body = await _read_chunked(reader, body, max_bytes=max_bytes)
elif content_length:
total = int(content_length)
if total > max_bytes:
raise RuntimeError(
"Relay response exceeds configured size cap "
f"({total} > {max_bytes} bytes)"
)
remaining = total - len(body)
while remaining > 0:
chunk = await asyncio.wait_for(
reader.read(min(remaining, 65536)), timeout=20
)
if not chunk:
break
body += chunk
if len(body) > max_bytes:
raise RuntimeError(
"Relay response exceeded configured size cap while reading body"
)
remaining -= len(chunk)
else:
# No framing — short timeout read (keep-alive safe)
while True:
try:
chunk = await asyncio.wait_for(reader.read(65536), timeout=2)
if not chunk:
break
body += chunk
if len(body) > max_bytes:
raise RuntimeError(
"Relay response exceeded configured size cap while streaming"
)
except asyncio.TimeoutError:
break
# ── Auto-decompress ───────────────────────────────────────────
enc = headers.get("content-encoding", "")
if enc:
body = codec.decode(body, enc)
if len(body) > max_bytes:
raise RuntimeError(
"Decoded relay response exceeded configured size cap"
)
return status, headers, body
async def _read_chunked(
reader: asyncio.StreamReader,
buf: bytes = b"",
*,
max_bytes: int,
) -> bytes:
"""Incrementally read a chunked-transfer-encoded body."""
result = b""
while True:
while b"\r\n" not in buf:
data = await asyncio.wait_for(reader.read(8192), timeout=20)
if not data:
return result
buf += data
end = buf.find(b"\r\n")
size_str = buf[:end].decode(errors="replace").strip()
buf = buf[end + 2:]
if not size_str:
continue
try:
size = int(size_str, 16)
except ValueError:
break
if size == 0:
break
if size > max_bytes or len(result) + size > max_bytes:
raise RuntimeError(
"Chunked relay response exceeded configured size cap "
f"({max_bytes} bytes)"
)
while len(buf) < size + 2:
data = await asyncio.wait_for(reader.read(65536), timeout=20)
if not data:
result += buf[:size]
return result
buf += data
result += buf[:size]
buf = buf[size + 2:]
return result
+323
View File
@@ -0,0 +1,323 @@
"""
Apps Script relay response parsing.
Pure functions for decoding the JSON envelope returned by Code.gs and
reconstructing a standard HTTP response that the proxy can forward to
the client browser.
Public API
----------
parse_relay_response(body, max_body_bytes) -> bytes
Top-level entry point: bytes → raw HTTP response bytes.
split_raw_response(raw) -> (status, headers, body)
Parse a raw HTTP byte string into its parts.
error_response(status, message) -> bytes
Build a minimal HTML error response.
classify_relay_error(raw) -> str
Map a raw Apps Script error string to a human-readable explanation.
"""
import base64
import codecs
import json
import logging
import re
log = logging.getLogger("Fronter")
__all__ = [
"classify_relay_error",
"error_response",
"split_raw_response",
"split_set_cookie",
"parse_relay_json",
"extract_apps_script_user_html",
"load_relay_json",
"parse_relay_response",
]
# ── Apps Script error pattern tables ─────────────────────────────────────────
# Matched against the lower-cased ``e`` field returned by Code.gs.
# Sources:
# • https://developers.google.com/apps-script/guides/support/troubleshooting
# • https://developers.google.com/apps-script/guides/services/quotas
# "Service invoked too many times for one day: urlfetch."
# "Bandwidth quota exceeded"
_QUOTA_PATTERNS = (
"service invoked too many times",
"invoked too many times",
"bandwidth quota exceeded",
"too much upload bandwidth",
"too much traffic",
"urlfetch", # appears at end of the daily-quota message in all locales
"quota",
"exceeded",
"daily",
"rate limit",
)
# "Authorization is required to perform that action."
# "unauthorized" (our own Code.gs response)
_AUTH_PATTERNS = (
"authorization is required",
"unauthorized",
"not authorized",
"permission denied",
"access denied",
)
# "Error occurred due to a missing library version or a deployment version.
# Error code Not_Found"
# "script id not found" / wrong Deployment ID
_DEPLOY_PATTERNS = (
"error code not_found",
"not_found",
"deployment",
"script id",
"scriptid",
"no script",
)
# "Server not available." / "Server error occurred, please try again."
_TRANSIENT_PATTERNS = (
"server not available",
"server error occurred",
"please try again",
"temporarily unavailable",
)
# "UrlFetch calls to <URL> are not permitted by your admin"
# "<Class> / Apiary.<Service> is disabled. Please contact your administrator"
_ADMIN_PATTERNS = (
"not permitted by your admin",
"contact your administrator",
"disabled. please contact",
"domain policy has disabled",
"administrator to enable",
)
# ── Error classifier ──────────────────────────────────────────────────────────
def classify_relay_error(raw: str) -> str:
"""Return a human-readable explanation for a known Apps Script error.
Covers every error category documented at:
developers.google.com/apps-script/guides/support/troubleshooting
"""
lower = raw.lower()
if any(p in lower for p in _QUOTA_PATTERNS):
return (
"Apps Script quota exhausted. "
"Either the 20,000 URL-fetch calls/day limit or the 100 MB/day "
"bandwidth limit has been reached. "
"Wait up to 24 hours for the quota to reset, or create a second "
"Google account, deploy a fresh Apps Script there, and add its "
"script_id to config.json."
)
if any(p in lower for p in _AUTH_PATTERNS):
return (
"Apps Script rejected the request (auth/permission error). "
"Check: (1) AUTH_KEY in Code.gs matches 'auth_key' in config.json, "
"(2) the deployment is set to 'Execute as: Me / Anyone can access', "
"(3) you are using the Deployment ID (not the Script ID), "
"(4) the owning Google account has authorised the script by running "
"it manually at least once."
)
if any(p in lower for p in _DEPLOY_PATTERNS):
return (
"Apps Script deployment not found. "
"Verify 'script_id' in config.json is the Deployment ID "
"(not the Script ID), the deployment is active/not archived, "
"and you re-created the deployment after editing Code.gs."
)
if any(p in lower for p in _TRANSIENT_PATTERNS):
return (
"Google Apps Script server is temporarily unavailable. "
"This is a transient Google-side error — wait a moment and retry. "
f"(raw: {raw})"
)
if any(p in lower for p in _ADMIN_PATTERNS):
return (
"Apps Script is blocked by a Google Workspace admin policy. "
"Either the target URL is not on the admin's UrlFetch allowlist, "
"or a Google service used by the script has been disabled by the "
"domain administrator. Contact your Google Workspace admin. "
f"(raw: {raw})"
)
# Unknown — strip the leading 'Exception: ' / 'Error: ' prefix that
# Apps Script always prepends, so the message is shorter and cleaner.
cleaned = re.sub(r'^(Exception|Error):\s*', '', raw, flags=re.IGNORECASE).strip()
return f"Relay error from Apps Script: {cleaned or raw}"
# ── Low-level HTTP helpers ────────────────────────────────────────────────────
def error_response(status: int, message: str) -> bytes:
"""Build a minimal HTML error response."""
body = f"<html><body><h1>{status}</h1><p>{message}</p></body></html>"
return (
f"HTTP/1.1 {status} Error\r\n"
f"Content-Type: text/html\r\n"
f"Content-Length: {len(body)}\r\n"
f"\r\n"
f"{body}"
).encode()
def split_raw_response(raw: bytes):
"""Split a raw HTTP response into ``(status, headers_dict, body)``."""
if b"\r\n\r\n" not in raw:
return 0, {}, raw
header_section, body = raw.split(b"\r\n\r\n", 1)
lines = header_section.split(b"\r\n")
m = re.search(r"\d{3}", lines[0].decode(errors="replace"))
status = int(m.group()) if m else 0
headers: dict[str, str] = {}
for line in lines[1:]:
if b":" in line:
k, v = line.decode(errors="replace").split(":", 1)
headers[k.strip().lower()] = v.strip()
return status, headers, body
def split_set_cookie(blob: str) -> list[str]:
"""Split a Set-Cookie string that may contain multiple cookies.
Apps Script sometimes joins multiple Set-Cookie values with ", ",
which collides with the comma that legitimately appears inside the
``Expires`` attribute (e.g. "Expires=Wed, 21 Oct 2026 ..."). We split
only on commas that are immediately followed by a cookie name=value
pair, leaving date commas intact.
"""
if not blob:
return []
parts = re.split(r",\s*(?=[A-Za-z0-9!#$%&'*+\-.^_`|~]+=)", blob)
return [p.strip() for p in parts if p.strip()]
# ── JSON → HTTP response ─────────────────────────────────────────────────────
def parse_relay_json(data: dict, max_body_bytes: int) -> bytes:
"""Convert a parsed relay JSON dict to raw HTTP response bytes."""
if "e" in data:
raw_err = str(data["e"])
friendly = classify_relay_error(raw_err)
log.warning("Apps Script error — %s | raw: %s", friendly.split(".")[0], raw_err)
return error_response(502, friendly)
status = data.get("s", 200)
resp_headers = data.get("h", {})
resp_body = base64.b64decode(data.get("b", ""))
if len(resp_body) > max_body_bytes:
return error_response(
502,
f"Relay response exceeds cap ({max_body_bytes} bytes). "
"Increase max_response_body_bytes if your system has enough RAM.",
)
status_text = {
200: "OK", 206: "Partial Content",
301: "Moved", 302: "Found", 304: "Not Modified",
400: "Bad Request", 403: "Forbidden", 404: "Not Found",
500: "Internal Server Error",
}.get(status, "OK")
result = f"HTTP/1.1 {status} {status_text}\r\n"
skip = {"transfer-encoding", "connection", "keep-alive",
"content-length", "content-encoding"}
for k, v in resp_headers.items():
if k.lower() in skip:
continue
# Apps Script returns multi-valued headers (e.g. Set-Cookie) as a
# JavaScript array. Emit each value as its own header line.
# A single string that holds multiple Set-Cookie values joined
# with ", " also needs to be split, otherwise the browser sees
# one malformed cookie and sites like x.com fail.
values = v if isinstance(v, list) else [v]
if k.lower() == "set-cookie":
expanded: list[str] = []
for item in values:
expanded.extend(split_set_cookie(str(item)))
values = expanded
for val in values:
result += f"{k}: {val}\r\n"
result += f"Content-Length: {len(resp_body)}\r\n"
result += "\r\n"
return result.encode() + resp_body
def extract_apps_script_user_html(text: str) -> str | None:
"""Extract embedded user HTML from an Apps Script HTML-page response."""
marker = 'goog.script.init("'
start = text.find(marker)
if start == -1:
return None
start += len(marker)
end = text.find('", "", undefined', start)
if end == -1:
return None
encoded = text[start:end]
try:
decoded = codecs.decode(encoded, "unicode_escape")
payload = json.loads(decoded)
except Exception:
return None
user_html = payload.get("userHtml")
return user_html if isinstance(user_html, str) else None
def load_relay_json(text: str) -> dict | None:
"""Parse a relay JSON body, handling Apps Script HTML wrappers."""
try:
return json.loads(text)
except json.JSONDecodeError:
wrapped = extract_apps_script_user_html(text)
if wrapped:
data = load_relay_json(wrapped)
if data is not None:
return data
match = re.search(r'\{.*\}', text, re.DOTALL)
if not match:
return None
try:
data = json.loads(match.group())
except json.JSONDecodeError:
return None
return data if isinstance(data, dict) else None
def parse_relay_response(body: bytes, max_body_bytes: int) -> bytes:
"""Parse a raw Apps Script response body into a raw HTTP response.
``body`` is the bytes returned over the TLS connection after stripping
the outer HTTP/1.1 response headers. The function:
1. Decodes the JSON envelope produced by Code.gs.
2. Unpacks the nested status / headers / base64-body fields.
3. Reconstructs a well-formed HTTP/1.1 response suitable for
forwarding directly to the browser.
"""
text = body.decode(errors="replace").strip()
if not text:
return error_response(502, "Empty response from relay")
data = load_relay_json(text)
if data is None:
return error_response(502, f"No JSON: {text[:200]}")
return parse_relay_json(data, max_body_bytes)