diff --git a/config.example.json b/config.example.json index 387a0ed..9f0d46d 100644 --- a/config.example.json +++ b/config.example.json @@ -7,8 +7,15 @@ "auth_key": "CHANGE_ME_TO_A_STRONG_SECRET", "listen_host": "127.0.0.1", "listen_port": 8085, + "socks5_enabled": true, + "socks5_host": "127.0.0.1", + "socks5_port": 1080, "log_level": "INFO", "verify_ssl": true, + "_direct_google_exclude_comment": "Google web apps that should NEVER use the raw direct-tunnel shortcut. Supports exact hosts and optional suffix patterns like \".googleapis.com\". They will go through the MITM relay path instead for better compatibility.", + "direct_google_exclude": ["gemini.google.com", "aistudio.google.com", "notebooklm.google.com", "labs.google.com", "meet.google.com", "accounts.google.com", "ogs.google.com", "mail.google.com", "calendar.google.com", "drive.google.com", "docs.google.com", "chat.google.com"], + "_direct_google_allow_comment": "Conservative allowlist for raw direct Google tunneling. Leave empty unless you have confirmed a host works better direct than via relay.", + "direct_google_allow": ["www.google.com", "safebrowsing.google.com"], "_hosts_comment": "Optional SNI-rewrite overrides. YouTube, googlevideo, gstatic, fonts.googleapis.com, ytimg, ggpht, doubleclick, etc. are ALREADY handled automatically (routed via google_ip with SNI=front_domain, same trick as the Xray MITM-DomainFronting config). Add entries here only for custom domains, e.g. \"example.com\": \"216.239.38.120\".", "hosts": {} } diff --git a/domain_fronter.py b/domain_fronter.py index a623cab..862c1ac 100644 --- a/domain_fronter.py +++ b/domain_fronter.py @@ -19,6 +19,7 @@ Mode 4 (apps_script): import asyncio import base64 +import hashlib import gzip import json import logging @@ -34,6 +35,12 @@ log = logging.getLogger("Fronter") class DomainFronter: + _STATIC_EXTS = ( + ".css", ".js", ".mjs", ".woff", ".woff2", ".ttf", ".eot", + ".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg", ".ico", + ".mp3", ".mp4", ".webm", ".wasm", ".avif", + ) + def __init__(self, config: dict): mode = config.get("mode", "domain_fronting") @@ -170,9 +177,34 @@ class DomainFronter: self._script_idx += 1 return sid - def _exec_path(self) -> str: - """Get the next Apps Script endpoint path (/dev or /exec).""" - sid = self._next_script_id() + @staticmethod + def _host_key(url_or_host: str | None) -> str: + """Return a stable routing key for a URL or host string.""" + if not url_or_host: + return "" + parsed = urlparse(url_or_host if "://" in url_or_host else f"https://{url_or_host}") + host = parsed.hostname or url_or_host + return host.lower().rstrip(".") + + def _script_id_for_key(self, key: str | None = None) -> str: + """Pick a stable Apps Script ID for a host or fallback to round-robin. + + When multiple deployments are configured, using a stable mapping per + host reduces IP/session churn for sites that are sensitive to endpoint + changes. If no key is available, we keep the older round-robin fallback + so warmup/keepalive traffic still distributes normally. + """ + if len(self._script_ids) == 1: + return self._script_ids[0] + if not key: + return self._next_script_id() + digest = hashlib.sha1(key.encode("utf-8")).digest() + idx = int.from_bytes(digest[:4], "big") % len(self._script_ids) + return self._script_ids[idx] + + def _exec_path(self, url_or_host: str | None = None) -> str: + """Get the Apps Script endpoint path (/dev or /exec).""" + sid = self._script_id_for_key(self._host_key(url_or_host)) return f"/macros/s/{sid}/{'dev' if self._dev_available else 'exec'}" async def _flush_pool(self): @@ -332,7 +364,7 @@ class DomainFronter: # Apps Script keepalive — warm the container payload = {"m": "HEAD", "u": "http://example.com/", "k": self.auth_key} - path = self._exec_path() + path = self._exec_path("example.com") t0 = time.perf_counter() await asyncio.wait_for( self._h2.request( @@ -514,6 +546,11 @@ class DomainFronter: payload = self._build_payload(method, url, headers, body) + # Stateful/browser-navigation requests should preserve exact ordering + # and header context; batching/coalescing is reserved for static fetches. + if self._is_stateful_request(method, url, headers, body): + return await self._relay_with_retry(payload) + # Coalesce concurrent GETs for the same URL. # CRITICAL: do NOT coalesce when a Range header is present — # parallel range downloads MUST each hit the server independently. @@ -711,7 +748,8 @@ class DomainFronter: payload = { "m": method, "u": url, - "r": True, + # Let the browser/app see origin redirects and cookies directly. + "r": False, } if headers: # Strip Accept-Encoding: Apps Script auto-decompresses gzip @@ -726,6 +764,46 @@ class DomainFronter: payload["ct"] = ct return payload + @classmethod + def _is_static_asset_url(cls, url: str) -> bool: + path = urlparse(url).path.lower() + return any(path.endswith(ext) for ext in cls._STATIC_EXTS) + + @staticmethod + def _header_value(headers: dict | None, name: str) -> str: + if not headers: + return "" + for key, value in headers.items(): + if key.lower() == name: + return str(value) + return "" + + @classmethod + def _is_stateful_request(cls, method: str, url: str, + headers: dict | None, body: bytes) -> bool: + method = method.upper() + if method not in {"GET", "HEAD"} or body: + return True + + if headers: + for name in ( + "cookie", "authorization", "proxy-authorization", + "origin", "referer", "if-none-match", "if-modified-since", + "cache-control", "pragma", + ): + if cls._header_value(headers, name): + return True + + accept = cls._header_value(headers, "accept").lower() + if "text/html" in accept or "application/json" in accept: + return True + + fetch_mode = cls._header_value(headers, "sec-fetch-mode").lower() + if fetch_mode in {"navigate", "cors"}: + return True + + return not cls._is_static_asset_url(url) + # ── Batch collector ─────────────────────────────────────────── async def _batch_submit(self, payload: dict) -> bytes: @@ -866,7 +944,7 @@ class DomainFronter: full_payload["k"] = self.auth_key json_body = json.dumps(full_payload).encode() - path = self._exec_path() + path = self._exec_path(payload.get("u")) status, headers, body = await self._h2.request( method="POST", path=path, host=self.http_host, @@ -883,7 +961,7 @@ class DomainFronter: full_payload["k"] = self.auth_key json_body = json.dumps(full_payload).encode() - path = self._exec_path() + path = self._exec_path(payload.get("u")) reader, writer, created = await self._acquire() try: @@ -911,14 +989,22 @@ class DomainFronter: parsed = urlparse(location) rpath = parsed.path + ("?" + parsed.query if parsed.query else "") - request = ( - f"GET {rpath} HTTP/1.1\r\n" - f"Host: {parsed.netloc}\r\n" - f"Accept-Encoding: gzip\r\n" - f"Connection: keep-alive\r\n" - f"\r\n" - ) - writer.write(request.encode()) + if status in (307, 308): + redirect_method = "POST" + redirect_body = json_body + else: + redirect_method = "GET" + redirect_body = b"" + request_lines = [ + f"{redirect_method} {rpath} HTTP/1.1", + f"Host: {parsed.netloc}", + "Accept-Encoding: gzip", + "Connection: keep-alive", + ] + if redirect_body: + request_lines.append(f"Content-Length: {len(redirect_body)}") + request = "\r\n".join(request_lines) + "\r\n\r\n" + writer.write(request.encode() + redirect_body) await writer.drain() status, resp_headers, resp_body = await self._read_http_response(reader) @@ -939,7 +1025,7 @@ class DomainFronter: "q": payloads, } json_body = json.dumps(batch_payload).encode() - path = self._exec_path() + path = self._exec_path(payloads[0].get("u") if payloads else None) # Try HTTP/2 first if self._h2 and self._h2.is_connected: @@ -983,14 +1069,22 @@ class DomainFronter: break parsed = urlparse(location) rpath = parsed.path + ("?" + parsed.query if parsed.query else "") - request = ( - f"GET {rpath} HTTP/1.1\r\n" - f"Host: {parsed.netloc}\r\n" - f"Accept-Encoding: gzip\r\n" - f"Connection: keep-alive\r\n" - f"\r\n" - ) - writer.write(request.encode()) + if status in (307, 308): + redirect_method = "POST" + redirect_body = json_body + else: + redirect_method = "GET" + redirect_body = b"" + request_lines = [ + f"{redirect_method} {rpath} HTTP/1.1", + f"Host: {parsed.netloc}", + "Accept-Encoding: gzip", + "Connection: keep-alive", + ] + if redirect_body: + request_lines.append(f"Content-Length: {len(redirect_body)}") + request = "\r\n".join(request_lines) + "\r\n\r\n" + writer.write(request.encode() + redirect_body) await writer.drain() status, resp_headers, resp_body = await self._read_http_response(reader) diff --git a/main.py b/main.py index d80add6..a191b58 100644 --- a/main.py +++ b/main.py @@ -51,6 +51,17 @@ def parse_args(): default=None, help="Override listen host (env: DFT_HOST)", ) + parser.add_argument( + "--socks5-port", + type=int, + default=None, + help="Override SOCKS5 listen port (env: DFT_SOCKS5_PORT)", + ) + parser.add_argument( + "--disable-socks5", + action="store_true", + help="Disable the built-in SOCKS5 listener.", + ) parser.add_argument( "--log-level", choices=["DEBUG", "INFO", "WARNING", "ERROR"], @@ -107,6 +118,14 @@ def main(): elif os.environ.get("DFT_HOST"): config["listen_host"] = os.environ["DFT_HOST"] + if args.socks5_port is not None: + config["socks5_port"] = args.socks5_port + elif os.environ.get("DFT_SOCKS5_PORT"): + config["socks5_port"] = int(os.environ["DFT_SOCKS5_PORT"]) + + if args.disable_socks5: + config["socks5_enabled"] = False + if args.log_level is not None: config["log_level"] = args.log_level elif os.environ.get("DFT_LOG_LEVEL"): @@ -162,7 +181,7 @@ def main(): config.get("front_domain", "www.google.com")) script_ids = config.get("script_ids") or config.get("script_id") if isinstance(script_ids, list): - log.info("Script IDs : %d scripts (round-robin)", len(script_ids)) + log.info("Script IDs : %d scripts (sticky per-host)", len(script_ids)) for i, sid in enumerate(script_ids): log.info(" [%d] %s", i + 1, sid) else: @@ -192,7 +211,13 @@ def main(): log.info("Front domain (SNI) : %s", config.get("front_domain", "?")) log.info("Worker host (Host) : %s", config.get("worker_host", "?")) - log.info("Proxy address : %s:%d", config.get("listen_host", "127.0.0.1"), config.get("listen_port", 8080)) + log.info("HTTP proxy : %s:%d", + config.get("listen_host", "127.0.0.1"), + config.get("listen_port", 8080)) + if config.get("socks5_enabled", True): + log.info("SOCKS5 proxy : %s:%d", + config.get("socks5_host", config.get("listen_host", "127.0.0.1")), + config.get("socks5_port", 1080)) try: asyncio.run(ProxyServer(config).start()) diff --git a/proxy_server.py b/proxy_server.py index fb9b1d8..ea9117c 100644 --- a/proxy_server.py +++ b/proxy_server.py @@ -12,8 +12,11 @@ Supports: import asyncio import logging import re +import socket import ssl import time +import ipaddress +from urllib.parse import urlparse from domain_fronter import DomainFronter @@ -69,7 +72,7 @@ class ResponseCache: # Don't cache errors or non-200 if b"HTTP/1.1 200" not in raw_response[:20]: return 0 - if "no-store" in hdr: + if "no-store" in hdr or "private" in hdr or "set-cookie:" in hdr: return 0 # Explicit max-age @@ -101,13 +104,57 @@ class ResponseCache: class ProxyServer: + _GOOGLE_DIRECT_EXACT_EXCLUDE = { + "gemini.google.com", + "aistudio.google.com", + "notebooklm.google.com", + "labs.google.com", + "meet.google.com", + "accounts.google.com", + "ogs.google.com", + "mail.google.com", + "calendar.google.com", + "drive.google.com", + "docs.google.com", + "chat.google.com", + "photos.google.com", + "maps.google.com", + "myaccount.google.com", + "contacts.google.com", + "classroom.google.com", + "keep.google.com", + "play.google.com", + } + _GOOGLE_DIRECT_SUFFIX_EXCLUDE = ( + ".meet.google.com", + ) + _GOOGLE_DIRECT_ALLOW_EXACT = { + "www.google.com", + "google.com", + "safebrowsing.google.com", + } + _GOOGLE_DIRECT_ALLOW_SUFFIXES = () + _TRACE_HOST_SUFFIXES = ( + "chatgpt.com", + "openai.com", + "gemini.google.com", + "google.com", + "cloudflare.com", + "challenges.cloudflare.com", + "turnstile", + ) + def __init__(self, config: dict): self.host = config.get("listen_host", "127.0.0.1") self.port = config.get("listen_port", 8080) + self.socks_enabled = config.get("socks5_enabled", True) + self.socks_host = config.get("socks5_host", self.host) + self.socks_port = config.get("socks5_port", 1080) self.mode = config.get("mode", "domain_fronting") self.fronter = DomainFronter(config) self.mitm = None self._cache = ResponseCache(max_mb=50) + self._direct_fail_until: dict[str, float] = {} # Persistent HTTP tunnel cache for google_fronting mode # Key: "host:port" → (tunnel_reader, tunnel_writer, lock) @@ -117,6 +164,22 @@ class ProxyServer: # hosts override — DNS fake-map: domain/suffix → IP # Checked before any real DNS lookup; supports exact and suffix matching. self._hosts: dict[str, str] = config.get("hosts", {}) + configured_direct_exclude = config.get("direct_google_exclude", []) + self._direct_google_exclude = { + h.lower().rstrip(".") + for h in ( + list(self._GOOGLE_DIRECT_EXACT_EXCLUDE) + + list(configured_direct_exclude) + ) + } + configured_direct_allow = config.get("direct_google_allow", []) + self._direct_google_allow = { + h.lower().rstrip(".") + for h in ( + list(self._GOOGLE_DIRECT_ALLOW_EXACT) + + list(configured_direct_allow) + ) + } if self.mode == "apps_script": try: @@ -127,14 +190,94 @@ class ProxyServer: log.error("Run: pip install cryptography") raise SystemExit(1) + @staticmethod + def _header_value(headers: dict | None, name: str) -> str: + if not headers: + return "" + for key, value in headers.items(): + if key.lower() == name: + return str(value) + return "" + + def _cache_allowed(self, method: str, url: str, + headers: dict | None, body: bytes) -> bool: + if method.upper() != "GET" or body: + return False + for name in ( + "cookie", "authorization", "proxy-authorization", "range", + "if-none-match", "if-modified-since", "cache-control", "pragma", + ): + if self._header_value(headers, name): + return False + return self.fronter._is_static_asset_url(url) + + @classmethod + def _should_trace_host(cls, host: str) -> bool: + h = host.lower().rstrip(".") + return any( + token == h or token in h or h.endswith("." + token) + for token in cls._TRACE_HOST_SUFFIXES + ) + + def _log_response_summary(self, url: str, response: bytes): + status, headers, body = self.fronter._split_raw_response(response) + host = (urlparse(url).hostname or "").lower() + if status >= 300 or self._should_trace_host(host): + location = headers.get("location", "") + server = headers.get("server", "") + cf_ray = headers.get("cf-ray", "") + content_type = headers.get("content-type", "") + body_len = len(body) + body_hint = "-" + if "text/html" in content_type.lower() and body: + sample = body[:800].decode(errors="replace").lower() + if "