diff --git a/src/core/adblock.py b/src/core/adblock.py index 309d65a..dadc544 100644 --- a/src/core/adblock.py +++ b/src/core/adblock.py @@ -36,7 +36,9 @@ _DEFAULT_MAX_AGE = 86_400 # 24 hours _DOWNLOAD_TIMEOUT = 30 # seconds per HTTP request # Cache sits next to the project root (same dir as main.py / config.json). -_CACHE_DIR = pathlib.Path("adblock_cache") +# Anchored to this file's location so the cache is always found regardless +# of the working directory the user launches the proxy from. +_CACHE_DIR = pathlib.Path(__file__).parent.parent.parent / "adblock_cache" # Patterns used during line parsing _IP_RE = re.compile( diff --git a/src/proxy/mitm.py b/src/proxy/mitm.py index ec04194..66637cd 100644 --- a/src/proxy/mitm.py +++ b/src/proxy/mitm.py @@ -137,6 +137,14 @@ class MITMCertManager: f.write(cert_pem + ca_pem) with open(key_file, "wb") as f: f.write(key_pem) + # Restrict private key to current user only on POSIX. + # os.chmod is effectively a no-op on Windows (NTFS ACLs govern + # access there), but the temp directory is already user-scoped. + if os.name == "posix": + try: + os.chmod(key_file, 0o600) + except OSError: + pass ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) ctx.set_alpn_protocols(["http/1.1"]) diff --git a/src/proxy/proxy_server.py b/src/proxy/proxy_server.py index b5e607d..fe2dbce 100644 --- a/src/proxy/proxy_server.py +++ b/src/proxy/proxy_server.py @@ -866,21 +866,14 @@ class ProxyServer: return writer._transport = new_transport - # Step 2: open outgoing TLS to target IP with the safe SNI - ssl_ctx_client = ssl.create_default_context() - if certifi is not None: - try: - ssl_ctx_client.load_verify_locations(cafile=certifi.where()) - except Exception: - pass - if not self.fronter.verify_ssl: - ssl_ctx_client.check_hostname = False - ssl_ctx_client.verify_mode = ssl.CERT_NONE + # Step 2: open outgoing TLS to target IP with the safe SNI. 
+ # Reuse the SSLContext already built by DomainFronter (certifi bundle, + # verify_ssl flag) — no need to rebuild it on every CONNECT. try: r_out, w_out = await asyncio.wait_for( asyncio.open_connection( target_ip, port, - ssl=ssl_ctx_client, + ssl=self.fronter._ssl_ctx(), server_hostname=sni_out, ), timeout=self._tcp_connect_timeout, @@ -1160,12 +1153,8 @@ class ProxyServer: """ if method == "GET" and not body: # Respect client's own Range header verbatim. - if headers: - for k in headers: - if k.lower() == "range": - return await self.fronter.relay( - method, url, headers, body - ) + if header_value(headers, "range"): + return await self.fronter.relay(method, url, headers, body) # Only probe with Range when the URL looks like a big file. if self._is_likely_download(url, headers): return await self.fronter.relay_parallel( @@ -1198,10 +1187,8 @@ class ProxyServer: writer) -> bool: if method.upper() != "GET" or body: return False - if headers: - for key in headers: - if key.lower() == "range": - return False + if header_value(headers, "range"): + return False effective_headers = headers or {} if not self._is_likely_download(url, effective_headers): return False diff --git a/src/proxy/proxy_support.py b/src/proxy/proxy_support.py index 32e7fb5..d32e6a9 100644 --- a/src/proxy/proxy_support.py +++ b/src/proxy/proxy_support.py @@ -230,7 +230,12 @@ class ResponseCache: if b"HTTP/1.1 200" not in raw_response[:20]: return 0 - if "no-store" in hdr or "private" in hdr or "set-cookie:" in hdr: + # Scope no-store / private checks to the Cache-Control header line so + # URLs like "Location: /api/private/…" or "Server: private-build" + # don't accidentally suppress caching for cacheable responses. 
+        if re.search(r"cache-control:[^\r\n]*\b(?:no-store|private)\b", hdr):
+            return 0
+        if "set-cookie:" in hdr:
             return 0
 
         max_age_match = re.search(r"max-age=(\d+)", hdr)
diff --git a/src/relay/domain_fronter.py b/src/relay/domain_fronter.py
index 1dd51b5..7621a84 100644
--- a/src/relay/domain_fronter.py
+++ b/src/relay/domain_fronter.py
@@ -2182,6 +2182,82 @@ class DomainFronter:
 
         return parse_relay_response(body, self._max_response_body_bytes)
 
+    async def _follow_redirects(
+        self,
+        reader: asyncio.StreamReader,
+        writer: asyncio.StreamWriter,
+        status: int,
+        resp_headers: dict,
+        resp_body: bytes,
+        original_body: bytes,
+        *,
+        original_method: str = "POST",
+    ) -> tuple[int, dict, bytes]:
+        """Follow up to 5 HTTP redirects on an existing H1 connection.
+
+        Every hop is written to ``writer`` and its response read from
+        ``reader``; only the request target and ``Host`` header change.
+
+        307/308 re-send ``original_method`` with ``original_body``; all
+        other redirect statuses become GET with an empty body
+        (RFC 7231 §6.4).
+
+        Args:
+            reader: Stream the follow-up responses are read from.
+            writer: Stream the follow-up requests are written to.
+            status: Status code of the response received so far.
+            resp_headers: Headers of that response; the target is taken
+                from its ``location`` key.
+            resp_body: Body of that response.
+            original_body: Request body to re-send on 307/308.
+            original_method: Request method to re-send on 307/308.
+
+        Returns:
+            The ``(status, headers, body)`` of the last response read,
+            which is the input triple itself when no redirect was
+            followed.
+        """
+        host = ""
+        for _ in range(5):
+            if status not in (301, 302, 303, 307, 308):
+                break
+            location = resp_headers.get("location")
+            if not location:
+                break
+            parsed = urlparse(location)
+            # A relative Location has no netloc: stay on the previous hop's
+            # host, and stop rather than send an empty Host header when
+            # no host is known yet.
+            host = parsed.netloc or host
+            if not host:
+                break
+            # "https://example.com" parses to an empty path, but the
+            # request line needs a non-empty target (RFC 7230 §5.3.1).
+            rpath = parsed.path or "/"
+            if parsed.query:
+                rpath += "?" + parsed.query
+            if status in (307, 308):
+                redirect_method = original_method
+                redirect_body = original_body
+            else:
+                redirect_method = "GET"
+                redirect_body = b""
+            request_lines = [
+                f"{redirect_method} {rpath} HTTP/1.1",
+                f"Host: {host}",
+                "Accept-Encoding: gzip",
+                "Connection: keep-alive",
+            ]
+            if redirect_body:
+                request_lines.append(f"Content-Length: {len(redirect_body)}")
+            request = "\r\n".join(request_lines) + "\r\n\r\n"
+            writer.write(request.encode() + redirect_body)
+            await writer.drain()
+            status, resp_headers, resp_body = await read_http_response(
+                reader, max_bytes=self._max_response_body_bytes
+            )
+        return status, resp_headers, resp_body
+
     async def _relay_single(self, payload: dict) -> bytes:
         """Execute a single relay POST → redirect → parse."""
         # Add auth key
@@ -2207,36 +2283,12 @@ class DomainFronter:
             await writer.drain()
             self._record_execution(sid)
 
-            status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes)
-
-            # Follow redirect chain on the SAME connection
-            for _ in range(5):
-                if status not in (301, 302, 303, 307, 308):
-                    break
-                location = resp_headers.get("location")
-                if not location:
-                    break
-
-                parsed = urlparse(location)
-                rpath = parsed.path + ("?"
+ parsed.query if parsed.query else "") - if status in (307, 308): - redirect_method = "POST" - redirect_body = json_body - else: - redirect_method = "GET" - redirect_body = b"" - request_lines = [ - f"{redirect_method} {rpath} HTTP/1.1", - f"Host: {parsed.netloc}", - "Accept-Encoding: gzip", - "Connection: keep-alive", - ] - if redirect_body: - request_lines.append(f"Content-Length: {len(redirect_body)}") - request = "\r\n".join(request_lines) + "\r\n\r\n" - writer.write(request.encode() + redirect_body) - await writer.drain() - status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes) + status, resp_headers, resp_body = await read_http_response( + reader, max_bytes=self._max_response_body_bytes + ) + status, resp_headers, resp_body = await self._follow_redirects( + reader, writer, status, resp_headers, resp_body, json_body + ) await self._release(reader, writer, created) return parse_relay_response(resp_body, self._max_response_body_bytes) @@ -2295,35 +2315,12 @@ class DomainFronter: await writer.drain() self._record_execution(sid) - status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes) - - # Follow redirects - for _ in range(5): - if status not in (301, 302, 303, 307, 308): - break - location = resp_headers.get("location") - if not location: - break - parsed = urlparse(location) - rpath = parsed.path + ("?" 
+ parsed.query if parsed.query else "") - if status in (307, 308): - redirect_method = "POST" - redirect_body = json_body - else: - redirect_method = "GET" - redirect_body = b"" - request_lines = [ - f"{redirect_method} {rpath} HTTP/1.1", - f"Host: {parsed.netloc}", - "Accept-Encoding: gzip", - "Connection: keep-alive", - ] - if redirect_body: - request_lines.append(f"Content-Length: {len(redirect_body)}") - request = "\r\n".join(request_lines) + "\r\n\r\n" - writer.write(request.encode() + redirect_body) - await writer.drain() - status, resp_headers, resp_body = await read_http_response(reader, max_bytes=self._max_response_body_bytes) + status, resp_headers, resp_body = await read_http_response( + reader, max_bytes=self._max_response_body_bytes + ) + status, resp_headers, resp_body = await self._follow_redirects( + reader, writer, status, resp_headers, resp_body, json_body + ) await self._release(reader, writer, created)