mirror of
https://github.com/denuitt1/mhr-cfw.git
synced 2026-05-17 21:24:36 +03:00
add src/cert_installer.py
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -1,419 +0,0 @@
|
||||
"""
|
||||
HTTP/2 multiplexed transport for domain-fronted connections.
|
||||
|
||||
One TLS connection → many concurrent HTTP/2 streams → massive throughput.
|
||||
Eliminates per-request TLS handshake overhead entirely.
|
||||
|
||||
Instead of a pool of 30 HTTP/1.1 connections (each handling 1 request),
|
||||
this uses a SINGLE HTTP/2 connection handling 100+ concurrent requests.
|
||||
|
||||
Performance comparison:
|
||||
HTTP/1.1 pool: 30 connections × 1 request = 30 concurrent requests max
|
||||
HTTP/2 mux: 1 connection × 100 streams = 100 concurrent requests
|
||||
|
||||
Requires: pip install h2
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import gzip
|
||||
import logging
|
||||
import socket
|
||||
import ssl
|
||||
from urllib.parse import urlparse
|
||||
|
||||
log = logging.getLogger("H2")
|
||||
|
||||
try:
|
||||
import h2.connection
|
||||
import h2.config
|
||||
import h2.events
|
||||
import h2.settings
|
||||
H2_AVAILABLE = True
|
||||
except ImportError:
|
||||
H2_AVAILABLE = False
|
||||
|
||||
|
||||
class _StreamState:
|
||||
"""State for a single in-flight HTTP/2 stream."""
|
||||
__slots__ = ("status", "headers", "data", "done", "error")
|
||||
|
||||
def __init__(self):
|
||||
self.status = 0
|
||||
self.headers: dict[str, str] = {}
|
||||
self.data = bytearray()
|
||||
self.done = asyncio.Event()
|
||||
self.error: str | None = None
|
||||
|
||||
|
||||
class H2Transport:
|
||||
"""
|
||||
Persistent HTTP/2 connection with automatic stream multiplexing.
|
||||
|
||||
All relay requests share ONE TLS connection. Each request becomes
|
||||
an independent HTTP/2 stream, running fully concurrently.
|
||||
|
||||
Features:
|
||||
- Auto-connect on first use
|
||||
- Auto-reconnect on connection loss
|
||||
- Redirect following (as new streams, same connection)
|
||||
- Gzip decompression
|
||||
- Configurable max concurrency
|
||||
"""
|
||||
|
||||
def __init__(self, connect_host: str, sni_host: str,
|
||||
verify_ssl: bool = True):
|
||||
self.connect_host = connect_host
|
||||
self.sni_host = sni_host
|
||||
self.verify_ssl = verify_ssl
|
||||
|
||||
self._reader: asyncio.StreamReader | None = None
|
||||
self._writer: asyncio.StreamWriter | None = None
|
||||
self._h2: "h2.connection.H2Connection | None" = None
|
||||
self._connected = False
|
||||
|
||||
self._write_lock = asyncio.Lock()
|
||||
self._connect_lock = asyncio.Lock()
|
||||
self._read_task: asyncio.Task | None = None
|
||||
|
||||
# Per-stream tracking
|
||||
self._streams: dict[int, _StreamState] = {}
|
||||
|
||||
# Stats
|
||||
self.total_requests = 0
|
||||
self.total_streams = 0
|
||||
|
||||
# ── Connection lifecycle ──────────────────────────────────────
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self._connected
|
||||
|
||||
async def ensure_connected(self):
|
||||
"""Connect if not already connected."""
|
||||
if self._connected:
|
||||
return
|
||||
async with self._connect_lock:
|
||||
if self._connected:
|
||||
return
|
||||
await self._do_connect()
|
||||
|
||||
async def _do_connect(self):
|
||||
"""Establish the HTTP/2 connection with optimized socket settings."""
|
||||
ctx = ssl.create_default_context()
|
||||
# Advertise both h2 and http/1.1 — some DPI blocks h2-only ALPN
|
||||
ctx.set_alpn_protocols(["h2", "http/1.1"])
|
||||
if not self.verify_ssl:
|
||||
ctx.check_hostname = False
|
||||
ctx.verify_mode = ssl.CERT_NONE
|
||||
|
||||
# Create raw TCP socket with TCP_NODELAY BEFORE TLS handshake.
|
||||
# Nagle's algorithm can delay small writes (H2 frames) by up to 200ms
|
||||
# waiting to coalesce — TCP_NODELAY forces immediate send.
|
||||
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
raw.setblocking(False)
|
||||
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.get_event_loop().sock_connect(
|
||||
raw, (self.connect_host, 443)
|
||||
),
|
||||
timeout=15,
|
||||
)
|
||||
self._reader, self._writer = await asyncio.wait_for(
|
||||
asyncio.open_connection(
|
||||
ssl=ctx,
|
||||
server_hostname=self.sni_host,
|
||||
sock=raw,
|
||||
),
|
||||
timeout=15,
|
||||
)
|
||||
except Exception:
|
||||
raw.close()
|
||||
raise
|
||||
|
||||
# Verify we actually got HTTP/2
|
||||
ssl_obj = self._writer.get_extra_info("ssl_object")
|
||||
negotiated = ssl_obj.selected_alpn_protocol() if ssl_obj else None
|
||||
if negotiated != "h2":
|
||||
self._writer.close()
|
||||
raise RuntimeError(
|
||||
f"H2 ALPN negotiation failed (got {negotiated!r})"
|
||||
)
|
||||
|
||||
config = h2.config.H2Configuration(
|
||||
client_side=True,
|
||||
header_encoding="utf-8",
|
||||
)
|
||||
self._h2 = h2.connection.H2Connection(config=config)
|
||||
self._h2.initiate_connection()
|
||||
|
||||
# Connection-level flow control: ~16MB window
|
||||
self._h2.increment_flow_control_window(2 ** 24 - 65535)
|
||||
|
||||
# Per-stream settings: 1MB initial window, disable server push
|
||||
self._h2.update_settings({
|
||||
h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 1 * 1024 * 1024,
|
||||
h2.settings.SettingCodes.ENABLE_PUSH: 0,
|
||||
})
|
||||
|
||||
await self._flush()
|
||||
|
||||
self._connected = True
|
||||
self._read_task = asyncio.create_task(self._reader_loop())
|
||||
log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)",
|
||||
self.connect_host, self.sni_host)
|
||||
|
||||
async def reconnect(self):
|
||||
"""Close current connection and re-establish."""
|
||||
await self._close_internal()
|
||||
await self._do_connect()
|
||||
|
||||
async def _close_internal(self):
|
||||
self._connected = False
|
||||
if self._read_task:
|
||||
self._read_task.cancel()
|
||||
self._read_task = None
|
||||
if self._writer:
|
||||
try:
|
||||
self._writer.close()
|
||||
except Exception:
|
||||
pass
|
||||
self._writer = None
|
||||
# Wake all pending streams so they can raise
|
||||
for state in self._streams.values():
|
||||
state.error = "Connection closed"
|
||||
state.done.set()
|
||||
self._streams.clear()
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────
|
||||
|
||||
async def request(self, method: str, path: str, host: str,
|
||||
headers: dict | None = None,
|
||||
body: bytes | None = None,
|
||||
timeout: float = 25,
|
||||
follow_redirects: int = 5) -> tuple[int, dict, bytes]:
|
||||
"""
|
||||
Send an HTTP/2 request and return (status, headers, body).
|
||||
|
||||
Thread-safe: many concurrent calls each get their own stream.
|
||||
Redirects are followed as new streams on the same connection.
|
||||
"""
|
||||
await self.ensure_connected()
|
||||
self.total_requests += 1
|
||||
|
||||
for _ in range(follow_redirects + 1):
|
||||
status, resp_headers, resp_body = await self._single_request(
|
||||
method, path, host, headers, body, timeout,
|
||||
)
|
||||
|
||||
if status not in (301, 302, 303, 307, 308):
|
||||
return status, resp_headers, resp_body
|
||||
|
||||
location = resp_headers.get("location", "")
|
||||
if not location:
|
||||
return status, resp_headers, resp_body
|
||||
|
||||
parsed = urlparse(location)
|
||||
path = parsed.path + ("?" + parsed.query if parsed.query else "")
|
||||
host = parsed.netloc or host
|
||||
method = "GET"
|
||||
body = None
|
||||
headers = None # Drop request headers on redirect
|
||||
|
||||
return status, resp_headers, resp_body
|
||||
|
||||
# ── Stream handling ───────────────────────────────────────────
|
||||
|
||||
async def _single_request(self, method, path, host, headers, body,
|
||||
timeout) -> tuple[int, dict, bytes]:
|
||||
"""Send one HTTP/2 request on a new stream, wait for response."""
|
||||
if not self._connected:
|
||||
await self.ensure_connected()
|
||||
|
||||
stream_id = None
|
||||
|
||||
async with self._write_lock:
|
||||
try:
|
||||
stream_id = self._h2.get_next_available_stream_id()
|
||||
except Exception:
|
||||
# Connection is stale — reconnect
|
||||
await self.reconnect()
|
||||
stream_id = self._h2.get_next_available_stream_id()
|
||||
|
||||
h2_headers = [
|
||||
(":method", method),
|
||||
(":path", path),
|
||||
(":authority", host),
|
||||
(":scheme", "https"),
|
||||
("accept-encoding", "gzip"),
|
||||
]
|
||||
if headers:
|
||||
for k, v in headers.items():
|
||||
h2_headers.append((k.lower(), str(v)))
|
||||
|
||||
end_stream = not body
|
||||
self._h2.send_headers(stream_id, h2_headers, end_stream=end_stream)
|
||||
|
||||
if body:
|
||||
# Send body (may need chunking for flow control)
|
||||
self._send_body(stream_id, body)
|
||||
|
||||
state = _StreamState()
|
||||
self._streams[stream_id] = state
|
||||
self.total_streams += 1
|
||||
|
||||
await self._flush()
|
||||
|
||||
# Wait for complete response
|
||||
try:
|
||||
await asyncio.wait_for(state.done.wait(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._streams.pop(stream_id, None)
|
||||
raise TimeoutError(
|
||||
f"H2 stream {stream_id} timed out ({timeout}s)"
|
||||
)
|
||||
|
||||
self._streams.pop(stream_id, None)
|
||||
|
||||
if state.error:
|
||||
raise ConnectionError(f"H2 stream error: {state.error}")
|
||||
|
||||
# Auto-decompress gzip
|
||||
resp_body = bytes(state.data)
|
||||
if state.headers.get("content-encoding", "").lower() == "gzip":
|
||||
try:
|
||||
resp_body = gzip.decompress(resp_body)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return state.status, state.headers, resp_body
|
||||
|
||||
def _send_body(self, stream_id: int, body: bytes):
|
||||
"""Send request body, respecting H2 flow control window."""
|
||||
# For small bodies (typical JSON payloads), send in one shot
|
||||
while body:
|
||||
max_size = self._h2.local_settings.max_frame_size
|
||||
window = self._h2.local_flow_control_window(stream_id)
|
||||
send_size = min(len(body), max_size, window)
|
||||
if send_size <= 0:
|
||||
# Flow control full — let the reader loop process
|
||||
# window updates before we continue
|
||||
break
|
||||
end = send_size >= len(body)
|
||||
self._h2.send_data(stream_id, body[:send_size], end_stream=end)
|
||||
body = body[send_size:]
|
||||
|
||||
# ── Background reader ─────────────────────────────────────────
|
||||
|
||||
async def _reader_loop(self):
|
||||
"""Background: read H2 frames, dispatch events to waiting streams."""
|
||||
try:
|
||||
while self._connected:
|
||||
data = await self._reader.read(65536)
|
||||
if not data:
|
||||
log.warning("H2 remote closed connection")
|
||||
break
|
||||
|
||||
try:
|
||||
events = self._h2.receive_data(data)
|
||||
except Exception as e:
|
||||
log.error("H2 protocol error: %s", e)
|
||||
break
|
||||
|
||||
for event in events:
|
||||
self._dispatch(event)
|
||||
|
||||
# Send pending data (acks, window updates, ping responses)
|
||||
async with self._write_lock:
|
||||
await self._flush()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
log.error("H2 reader error: %s", e)
|
||||
finally:
|
||||
self._connected = False
|
||||
for state in self._streams.values():
|
||||
if not state.done.is_set():
|
||||
state.error = "Connection lost"
|
||||
state.done.set()
|
||||
log.info("H2 reader loop ended")
|
||||
|
||||
def _dispatch(self, event):
|
||||
"""Route a single h2 event to its stream."""
|
||||
if isinstance(event, h2.events.ResponseReceived):
|
||||
state = self._streams.get(event.stream_id)
|
||||
if state:
|
||||
for name, value in event.headers:
|
||||
n = name if isinstance(name, str) else name.decode()
|
||||
v = value if isinstance(value, str) else value.decode()
|
||||
if n == ":status":
|
||||
state.status = int(v)
|
||||
else:
|
||||
state.headers[n] = v
|
||||
|
||||
elif isinstance(event, h2.events.DataReceived):
|
||||
state = self._streams.get(event.stream_id)
|
||||
if state:
|
||||
state.data.extend(event.data)
|
||||
# Always acknowledge received data for flow control
|
||||
self._h2.acknowledge_received_data(
|
||||
event.flow_controlled_length, event.stream_id
|
||||
)
|
||||
|
||||
elif isinstance(event, h2.events.StreamEnded):
|
||||
state = self._streams.get(event.stream_id)
|
||||
if state:
|
||||
state.done.set()
|
||||
|
||||
elif isinstance(event, h2.events.StreamReset):
|
||||
state = self._streams.get(event.stream_id)
|
||||
if state:
|
||||
state.error = f"Stream reset (code={event.error_code})"
|
||||
state.done.set()
|
||||
|
||||
elif isinstance(event, h2.events.WindowUpdated):
|
||||
pass # h2 library handles window bookkeeping
|
||||
|
||||
elif isinstance(event, h2.events.SettingsAcknowledged):
|
||||
pass
|
||||
|
||||
elif isinstance(event, h2.events.PingReceived):
|
||||
pass # h2 library auto-responds
|
||||
|
||||
elif isinstance(event, h2.events.PingAckReceived):
|
||||
pass # keepalive confirmed
|
||||
|
||||
# ── Internal ──────────────────────────────────────────────────
|
||||
|
||||
async def _flush(self):
|
||||
"""Write pending H2 frame data to the socket."""
|
||||
data = self._h2.data_to_send()
|
||||
if data and self._writer:
|
||||
self._writer.write(data)
|
||||
await self._writer.drain()
|
||||
|
||||
async def close(self):
|
||||
"""Gracefully close the HTTP/2 connection."""
|
||||
if self._h2 and self._connected:
|
||||
try:
|
||||
self._h2.close_connection()
|
||||
async with self._write_lock:
|
||||
await self._flush()
|
||||
except Exception:
|
||||
pass
|
||||
await self._close_internal()
|
||||
|
||||
async def ping(self):
|
||||
"""Send an H2 PING frame to keep the connection alive."""
|
||||
if not self._connected or not self._h2:
|
||||
return
|
||||
try:
|
||||
async with self._write_lock:
|
||||
if not self._connected:
|
||||
return
|
||||
self._h2.ping(b"\x00" * 8)
|
||||
await self._flush()
|
||||
except Exception as e:
|
||||
log.debug("H2 PING failed: %s", e)
|
||||
-153
@@ -1,153 +0,0 @@
|
||||
"""
|
||||
MITM certificate manager for HTTPS interception.
|
||||
|
||||
Generates a CA certificate (once, stored as files) and per-domain
|
||||
certificates (on the fly, cached in memory) so the local proxy can
|
||||
decrypt HTTPS traffic and relay it through Apps Script.
|
||||
|
||||
The user must install ca/ca.crt in their browser's trusted CAs once.
|
||||
|
||||
Requires: pip install cryptography
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import tempfile
|
||||
|
||||
from cryptography import x509
|
||||
from cryptography.hazmat.primitives import hashes, serialization
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from cryptography.x509.oid import NameOID
|
||||
|
||||
log = logging.getLogger("MITM")
|
||||
|
||||
CA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../cert")
|
||||
CA_KEY_FILE = os.path.join(CA_DIR, "ca.key")
|
||||
CA_CERT_FILE = os.path.join(CA_DIR, "ca.crt")
|
||||
|
||||
|
||||
class MITMCertManager:
|
||||
def __init__(self):
|
||||
self._ca_key = None
|
||||
self._ca_cert = None
|
||||
self._ctx_cache: dict[str, ssl.SSLContext] = {}
|
||||
self._cert_dir = tempfile.mkdtemp(prefix="domainfront_certs_")
|
||||
self._ensure_ca()
|
||||
|
||||
def _ensure_ca(self):
|
||||
if os.path.exists(CA_KEY_FILE) and os.path.exists(CA_CERT_FILE):
|
||||
with open(CA_KEY_FILE, "rb") as f:
|
||||
self._ca_key = serialization.load_pem_private_key(
|
||||
f.read(), password=None
|
||||
)
|
||||
with open(CA_CERT_FILE, "rb") as f:
|
||||
self._ca_cert = x509.load_pem_x509_certificate(f.read())
|
||||
log.info("Loaded CA from %s", CA_DIR)
|
||||
else:
|
||||
self._create_ca()
|
||||
|
||||
def _create_ca(self):
|
||||
os.makedirs(CA_DIR, exist_ok=True)
|
||||
|
||||
self._ca_key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048
|
||||
)
|
||||
subject = issuer = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, "MHR_CFW"),
|
||||
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "MHR_CFW"),
|
||||
])
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
self._ca_cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(issuer)
|
||||
.public_key(self._ca_key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=3650))
|
||||
.add_extension(
|
||||
x509.BasicConstraints(ca=True, path_length=0), critical=True
|
||||
)
|
||||
.add_extension(
|
||||
x509.KeyUsage(
|
||||
digital_signature=True,
|
||||
key_cert_sign=True,
|
||||
crl_sign=True,
|
||||
content_commitment=False,
|
||||
key_encipherment=False,
|
||||
data_encipherment=False,
|
||||
key_agreement=False,
|
||||
encipher_only=False,
|
||||
decipher_only=False,
|
||||
),
|
||||
critical=True,
|
||||
)
|
||||
.sign(self._ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
with open(CA_KEY_FILE, "wb") as f:
|
||||
f.write(
|
||||
self._ca_key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption(),
|
||||
)
|
||||
)
|
||||
with open(CA_CERT_FILE, "wb") as f:
|
||||
f.write(self._ca_cert.public_bytes(serialization.Encoding.PEM))
|
||||
|
||||
log.warning("Generated new CA certificate: %s", CA_CERT_FILE)
|
||||
log.warning(">>> Install this file in your browser's Trusted Root CAs! <<<")
|
||||
|
||||
def get_server_context(self, domain: str) -> ssl.SSLContext:
|
||||
if domain not in self._ctx_cache:
|
||||
key_pem, cert_pem = self._generate_domain_cert(domain)
|
||||
|
||||
cert_file = os.path.join(self._cert_dir, f"{domain}.crt")
|
||||
key_file = os.path.join(self._cert_dir, f"{domain}.key")
|
||||
|
||||
ca_pem = self._ca_cert.public_bytes(serialization.Encoding.PEM)
|
||||
with open(cert_file, "wb") as f:
|
||||
f.write(cert_pem + ca_pem)
|
||||
with open(key_file, "wb") as f:
|
||||
f.write(key_pem)
|
||||
|
||||
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
ctx.set_alpn_protocols(["http/1.1"])
|
||||
ctx.load_cert_chain(cert_file, key_file)
|
||||
self._ctx_cache[domain] = ctx
|
||||
|
||||
return self._ctx_cache[domain]
|
||||
|
||||
def _generate_domain_cert(self, domain: str):
|
||||
key = rsa.generate_private_key(
|
||||
public_exponent=65537, key_size=2048
|
||||
)
|
||||
subject = x509.Name([
|
||||
x509.NameAttribute(NameOID.COMMON_NAME, domain),
|
||||
])
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
cert = (
|
||||
x509.CertificateBuilder()
|
||||
.subject_name(subject)
|
||||
.issuer_name(self._ca_cert.subject)
|
||||
.public_key(key.public_key())
|
||||
.serial_number(x509.random_serial_number())
|
||||
.not_valid_before(now)
|
||||
.not_valid_after(now + datetime.timedelta(days=365))
|
||||
.add_extension(
|
||||
x509.SubjectAlternativeName([x509.DNSName(domain)]),
|
||||
critical=False,
|
||||
)
|
||||
.sign(self._ca_key, hashes.SHA256())
|
||||
)
|
||||
|
||||
key_pem = key.private_bytes(
|
||||
serialization.Encoding.PEM,
|
||||
serialization.PrivateFormat.TraditionalOpenSSL,
|
||||
serialization.NoEncryption(),
|
||||
)
|
||||
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
|
||||
return key_pem, cert_pem
|
||||
@@ -1,777 +0,0 @@
|
||||
"""
|
||||
Local HTTP proxy server.
|
||||
|
||||
Intercepts the user's browser traffic and forwards everything through
|
||||
a domain-fronted connection to a CDN worker or Apps Script relay.
|
||||
|
||||
Supports:
|
||||
- CONNECT method → WebSocket tunnel (modes 1-3) or MITM relay (apps_script)
|
||||
- GET / POST etc. → HTTP forwarding (modes 1-3) or JSON relay (apps_script)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import re
|
||||
import ssl
|
||||
import time
|
||||
|
||||
from core.domain_fronter import DomainFronter
|
||||
|
||||
log = logging.getLogger("Proxy")
|
||||
|
||||
|
||||
class ResponseCache:
|
||||
"""Simple LRU response cache — avoids repeated relay calls."""
|
||||
|
||||
def __init__(self, max_mb: int = 50):
|
||||
self._store: dict[str, tuple[bytes, float]] = {}
|
||||
self._size = 0
|
||||
self._max = max_mb * 1024 * 1024
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
|
||||
def get(self, url: str) -> bytes | None:
|
||||
entry = self._store.get(url)
|
||||
if not entry:
|
||||
self.misses += 1
|
||||
return None
|
||||
raw, expires = entry
|
||||
if time.time() > expires:
|
||||
self._size -= len(raw)
|
||||
del self._store[url]
|
||||
self.misses += 1
|
||||
return None
|
||||
self.hits += 1
|
||||
return raw
|
||||
|
||||
def put(self, url: str, raw_response: bytes, ttl: int = 300):
|
||||
size = len(raw_response)
|
||||
if size > self._max // 4 or size == 0:
|
||||
return
|
||||
# Evict oldest to make room
|
||||
while self._size + size > self._max and self._store:
|
||||
oldest = next(iter(self._store))
|
||||
self._size -= len(self._store[oldest][0])
|
||||
del self._store[oldest]
|
||||
if url in self._store:
|
||||
self._size -= len(self._store[url][0])
|
||||
self._store[url] = (raw_response, time.time() + ttl)
|
||||
self._size += size
|
||||
|
||||
@staticmethod
|
||||
def parse_ttl(raw_response: bytes, url: str) -> int:
|
||||
"""Determine cache TTL from response headers and URL."""
|
||||
hdr_end = raw_response.find(b"\r\n\r\n")
|
||||
if hdr_end < 0:
|
||||
return 0
|
||||
hdr = raw_response[:hdr_end].decode(errors="replace").lower()
|
||||
|
||||
# Don't cache errors or non-200
|
||||
if b"HTTP/1.1 200" not in raw_response[:20]:
|
||||
return 0
|
||||
if "no-store" in hdr:
|
||||
return 0
|
||||
|
||||
# Explicit max-age
|
||||
m = re.search(r"max-age=(\d+)", hdr)
|
||||
if m:
|
||||
return min(int(m.group(1)), 86400)
|
||||
|
||||
# Heuristic by content type / extension
|
||||
path = url.split("?")[0].lower()
|
||||
static_exts = (
|
||||
".css", ".js", ".woff", ".woff2", ".ttf", ".eot",
|
||||
".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg", ".ico",
|
||||
".mp3", ".mp4", ".wasm",
|
||||
)
|
||||
for ext in static_exts:
|
||||
if path.endswith(ext):
|
||||
return 3600 # 1 hour for static assets
|
||||
|
||||
ct_m = re.search(r"content-type:\s*([^\r\n]+)", hdr)
|
||||
ct = ct_m.group(1) if ct_m else ""
|
||||
if "image/" in ct or "font/" in ct:
|
||||
return 3600
|
||||
if "text/css" in ct or "javascript" in ct:
|
||||
return 1800
|
||||
if "text/html" in ct or "application/json" in ct:
|
||||
return 0 # don't cache dynamic content by default
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
class ProxyServer:
|
||||
def __init__(self, config: dict):
|
||||
self.host = config.get("listen_host", "127.0.0.1")
|
||||
self.port = config.get("listen_port", 8080)
|
||||
self.mode = config.get("mode", "domain_fronting")
|
||||
self.fronter = DomainFronter(config)
|
||||
self.mitm = None
|
||||
self._cache = ResponseCache(max_mb=50)
|
||||
|
||||
# Persistent HTTP tunnel cache for google_fronting mode
|
||||
# Key: "host:port" → (tunnel_reader, tunnel_writer, lock)
|
||||
self._http_tunnels: dict = {}
|
||||
self._tunnel_lock = asyncio.Lock()
|
||||
|
||||
# hosts override — DNS fake-map: domain/suffix → IP
|
||||
# Checked before any real DNS lookup; supports exact and suffix matching.
|
||||
self._hosts: dict[str, str] = config.get("hosts", {})
|
||||
|
||||
if self.mode == "apps_script":
|
||||
try:
|
||||
from core.mitm import MITMCertManager
|
||||
self.mitm = MITMCertManager()
|
||||
except ImportError:
|
||||
log.error("apps_script mode requires 'cryptography' package.")
|
||||
log.error("Run: pip install cryptography")
|
||||
raise SystemExit(1)
|
||||
|
||||
async def start(self):
|
||||
srv = await asyncio.start_server(self._on_client, self.host, self.port)
|
||||
log.info(
|
||||
"Listening on %s:%d — configure your browser HTTP proxy to this address",
|
||||
self.host, self.port,
|
||||
)
|
||||
async with srv:
|
||||
await srv.serve_forever()
|
||||
|
||||
# ── client handler ────────────────────────────────────────────
|
||||
|
||||
async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
|
||||
addr = writer.get_extra_info("peername")
|
||||
try:
|
||||
first_line = await asyncio.wait_for(reader.readline(), timeout=30)
|
||||
if not first_line:
|
||||
return
|
||||
|
||||
# Read remaining headers
|
||||
header_block = first_line
|
||||
while True:
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=10)
|
||||
header_block += line
|
||||
if line in (b"\r\n", b"\n", b""):
|
||||
break
|
||||
|
||||
request_line = first_line.decode(errors="replace").strip()
|
||||
parts = request_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
return
|
||||
|
||||
method = parts[0].upper()
|
||||
|
||||
if method == "CONNECT":
|
||||
await self._do_connect(parts[1], reader, writer)
|
||||
else:
|
||||
await self._do_http(header_block, reader, writer)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
log.debug("Timeout: %s", addr)
|
||||
except Exception as e:
|
||||
log.error("Error (%s): %s", addr, e)
|
||||
finally:
|
||||
try:
|
||||
writer.close()
|
||||
await writer.wait_closed()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ── CONNECT (HTTPS tunnelling) ────────────────────────────────
|
||||
|
||||
async def _do_connect(self, target: str, reader, writer):
|
||||
host, _, port = target.rpartition(":")
|
||||
port = int(port) if port else 443
|
||||
if not host:
|
||||
host, port = target, 443
|
||||
|
||||
log.info("CONNECT → %s:%d", host, port)
|
||||
|
||||
writer.write(b"HTTP/1.1 200 Connection Established\r\n\r\n")
|
||||
await writer.drain()
|
||||
|
||||
if self.mode == "apps_script":
|
||||
override_ip = self._sni_rewrite_ip(host)
|
||||
if override_ip:
|
||||
# SNI-blocked domain: MITM-decrypt from browser, then
|
||||
# re-connect to the override IP with SNI=front_domain so
|
||||
# the ISP never sees the blocked hostname in the TLS handshake.
|
||||
log.info("SNI-rewrite tunnel → %s via %s (SNI: %s)",
|
||||
host, override_ip, self.fronter.sni_host)
|
||||
await self._do_sni_rewrite_tunnel(host, port, reader, writer,
|
||||
connect_ip=override_ip)
|
||||
elif self._is_google_domain(host):
|
||||
log.info("Direct tunnel → %s (Google domain, skipping relay)", host)
|
||||
await self._do_direct_tunnel(host, port, reader, writer)
|
||||
else:
|
||||
await self._do_mitm_connect(host, port, reader, writer)
|
||||
else:
|
||||
await self.fronter.tunnel(host, port, reader, writer)
|
||||
|
||||
# ── Hosts override (fake DNS) ─────────────────────────────────
|
||||
|
||||
# Built-in list of domains that must be reached via Google's frontend IP
|
||||
# with SNI rewritten to `front_domain` (default: www.google.com).
|
||||
# These are Google-owned services whose real SNI is DPI-blocked in some
|
||||
# countries, but that Google serves from the same edge IP as www.google.com.
|
||||
# Users don't need to configure anything — any host matching one of these
|
||||
# suffixes is transparently SNI-rewritten to the configured `google_ip`.
|
||||
# Config's "hosts" map still takes precedence (for custom overrides).
|
||||
_SNI_REWRITE_SUFFIXES = (
|
||||
"youtube.com",
|
||||
"youtu.be",
|
||||
"youtube-nocookie.com",
|
||||
"ytimg.com",
|
||||
"ggpht.com",
|
||||
"gvt1.com",
|
||||
"gvt2.com",
|
||||
"doubleclick.net",
|
||||
"googlesyndication.com",
|
||||
"googleadservices.com",
|
||||
"google-analytics.com",
|
||||
"googletagmanager.com",
|
||||
"googletagservices.com",
|
||||
"fonts.googleapis.com",
|
||||
)
|
||||
|
||||
def _sni_rewrite_ip(self, host: str) -> str | None:
|
||||
"""Return the IP to SNI-rewrite `host` through, or None.
|
||||
|
||||
Order of precedence:
|
||||
1. Explicit entry in config `hosts` map (exact or suffix match).
|
||||
2. Built-in `_SNI_REWRITE_SUFFIXES` → mapped to config `google_ip`.
|
||||
"""
|
||||
ip = self._hosts_ip(host)
|
||||
if ip:
|
||||
return ip
|
||||
h = host.lower().rstrip(".")
|
||||
for suffix in self._SNI_REWRITE_SUFFIXES:
|
||||
if h == suffix or h.endswith("." + suffix):
|
||||
return self.fronter.connect_host # configured google_ip
|
||||
return None
|
||||
|
||||
def _hosts_ip(self, host: str) -> str | None:
|
||||
"""Return override IP for host if defined in config 'hosts', else None.
|
||||
|
||||
Supports exact match and suffix match (e.g. 'youtube.com' matches
|
||||
'www.youtube.com', 'm.youtube.com', etc.).
|
||||
"""
|
||||
h = host.lower().rstrip(".")
|
||||
if h in self._hosts:
|
||||
return self._hosts[h]
|
||||
# suffix match: check every parent label
|
||||
parts = h.split(".")
|
||||
for i in range(1, len(parts)):
|
||||
parent = ".".join(parts[i:])
|
||||
if parent in self._hosts:
|
||||
return self._hosts[parent]
|
||||
return None
|
||||
|
||||
# ── Google domain detection ───────────────────────────────────
|
||||
|
||||
# Only domains whose SNI the ISP does NOT block — direct tunnel is safe.
|
||||
# YouTube/googlevideo SNIs are blocked; they go through _do_sni_rewrite_tunnel
|
||||
# via the hosts map instead.
|
||||
_GOOGLE_SUFFIXES = (
|
||||
".google.com", ".google.co",
|
||||
".googleapis.com", ".gstatic.com",
|
||||
".googleusercontent.com",
|
||||
)
|
||||
_GOOGLE_EXACT = {
|
||||
"google.com", "gstatic.com", "googleapis.com",
|
||||
}
|
||||
|
||||
def _is_google_domain(self, host: str) -> bool:
|
||||
"""Return True if host is a Google-owned domain."""
|
||||
h = host.lower().rstrip(".")
|
||||
if h in self._GOOGLE_EXACT:
|
||||
return True
|
||||
for suffix in self._GOOGLE_SUFFIXES:
|
||||
if h.endswith(suffix):
|
||||
return True
|
||||
return False
|
||||
|
||||
# ── Direct tunnel (no MITM) ───────────────────────────────────
|
||||
|
||||
async def _do_direct_tunnel(self, host: str, port: int,
|
||||
reader: asyncio.StreamReader,
|
||||
writer: asyncio.StreamWriter,
|
||||
connect_ip: str | None = None):
|
||||
"""Pipe raw TLS bytes directly to the target server.
|
||||
|
||||
connect_ip overrides DNS: the TCP connection goes to that IP
|
||||
while the browser's TLS (SNI=host) is piped through unchanged.
|
||||
Defaults to the configured google_ip for Google-category domains.
|
||||
"""
|
||||
target_ip = connect_ip or self.fronter.connect_host
|
||||
try:
|
||||
r_remote, w_remote = await asyncio.wait_for(
|
||||
asyncio.open_connection(target_ip, port), timeout=10
|
||||
)
|
||||
except Exception as e:
|
||||
log.error("Direct tunnel connect failed (%s via %s): %s",
|
||||
host, target_ip, e)
|
||||
return
|
||||
|
||||
async def pipe(src, dst, label):
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(65536)
|
||||
if not data:
|
||||
break
|
||||
dst.write(data)
|
||||
await dst.drain()
|
||||
except (ConnectionError, asyncio.CancelledError):
|
||||
pass
|
||||
except Exception as e:
|
||||
log.debug("Pipe %s ended: %s", label, e)
|
||||
finally:
|
||||
try:
|
||||
dst.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.gather(
|
||||
pipe(reader, w_remote, f"client→{host}"),
|
||||
pipe(r_remote, writer, f"{host}→client"),
|
||||
)
|
||||
|
||||
# ── SNI-rewrite tunnel ────────────────────────────────────────
|
||||
|
||||
async def _do_sni_rewrite_tunnel(self, host: str, port: int, reader, writer,
|
||||
connect_ip: str | None = None):
|
||||
"""MITM-decrypt TLS from browser, then re-encrypt toward connect_ip
|
||||
using SNI=front_domain (e.g. www.google.com).
|
||||
|
||||
The ISP only ever sees SNI=www.google.com in the outgoing handshake,
|
||||
hiding the blocked hostname (e.g. www.youtube.com).
|
||||
"""
|
||||
target_ip = connect_ip or self.fronter.connect_host
|
||||
sni_out = self.fronter.sni_host # e.g. "www.google.com"
|
||||
|
||||
# Step 1: MITM — accept TLS from the browser
|
||||
ssl_ctx_server = self.mitm.get_server_context(host)
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = writer.transport
|
||||
protocol = transport.get_protocol()
|
||||
try:
|
||||
new_transport = await loop.start_tls(
|
||||
transport, protocol, ssl_ctx_server, server_side=True,
|
||||
)
|
||||
except Exception as e:
|
||||
log.debug("SNI-rewrite TLS accept failed (%s): %s", host, e)
|
||||
return
|
||||
writer._transport = new_transport
|
||||
|
||||
# Step 2: open outgoing TLS to target IP with the safe SNI
|
||||
ssl_ctx_client = ssl.create_default_context()
|
||||
if not self.fronter.verify_ssl:
|
||||
ssl_ctx_client.check_hostname = False
|
||||
ssl_ctx_client.verify_mode = ssl.CERT_NONE
|
||||
try:
|
||||
r_out, w_out = await asyncio.wait_for(
|
||||
asyncio.open_connection(
|
||||
target_ip, port,
|
||||
ssl=ssl_ctx_client,
|
||||
server_hostname=sni_out,
|
||||
),
|
||||
timeout=10,
|
||||
)
|
||||
except Exception as e:
|
||||
log.error("SNI-rewrite outbound connect failed (%s via %s): %s",
|
||||
host, target_ip, e)
|
||||
return
|
||||
|
||||
# Step 3: pipe application-layer bytes between the two TLS sessions
|
||||
async def pipe(src, dst, label):
|
||||
try:
|
||||
while True:
|
||||
data = await src.read(65536)
|
||||
if not data:
|
||||
break
|
||||
dst.write(data)
|
||||
await dst.drain()
|
||||
except (ConnectionError, asyncio.CancelledError):
|
||||
pass
|
||||
except Exception as exc:
|
||||
log.debug("Pipe %s ended: %s", label, exc)
|
||||
finally:
|
||||
try:
|
||||
dst.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
await asyncio.gather(
|
||||
pipe(reader, w_out, f"client→{host}"),
|
||||
pipe(r_out, writer, f"{host}→client"),
|
||||
)
|
||||
|
||||
# ── MITM CONNECT (apps_script mode) ───────────────────────────
|
||||
|
||||
async def _do_mitm_connect(self, host: str, port: int, reader, writer):
|
||||
"""Intercept TLS, decrypt HTTP, and relay through Apps Script."""
|
||||
ssl_ctx = self.mitm.get_server_context(host)
|
||||
|
||||
# Upgrade the existing connection to TLS (we are the server)
|
||||
loop = asyncio.get_event_loop()
|
||||
transport = writer.transport
|
||||
protocol = transport.get_protocol()
|
||||
|
||||
try:
|
||||
new_transport = await loop.start_tls(
|
||||
transport, protocol, ssl_ctx, server_side=True,
|
||||
)
|
||||
except Exception as e:
|
||||
# Non-HTTPS traffic (e.g. MTProto, plain HTTP on port 80/443)
|
||||
# routed through the proxy will always fail TLS — log at DEBUG
|
||||
# to avoid alarming noise.
|
||||
if port != 443:
|
||||
log.debug("TLS handshake skipped for %s:%d (non-HTTPS): %s", host, port, e)
|
||||
else:
|
||||
log.debug("TLS handshake failed for %s: %s", host, e)
|
||||
return
|
||||
|
||||
# Update writer to use the new TLS transport
|
||||
writer._transport = new_transport
|
||||
|
||||
# Read and relay HTTP requests from the browser (now decrypted)
|
||||
while True:
|
||||
try:
|
||||
first_line = await asyncio.wait_for(reader.readline(), timeout=120)
|
||||
if not first_line:
|
||||
break
|
||||
|
||||
header_block = first_line
|
||||
while True:
|
||||
line = await asyncio.wait_for(reader.readline(), timeout=10)
|
||||
header_block += line
|
||||
if line in (b"\r\n", b"\n", b""):
|
||||
break
|
||||
|
||||
# Read body
|
||||
body = b""
|
||||
for raw_line in header_block.split(b"\r\n"):
|
||||
if raw_line.lower().startswith(b"content-length:"):
|
||||
length = int(raw_line.split(b":", 1)[1].strip())
|
||||
body = await reader.readexactly(length)
|
||||
break
|
||||
|
||||
# Parse the request
|
||||
request_line = first_line.decode(errors="replace").strip()
|
||||
parts = request_line.split(" ", 2)
|
||||
if len(parts) < 2:
|
||||
break
|
||||
|
||||
method = parts[0]
|
||||
path = parts[1]
|
||||
|
||||
# Parse headers
|
||||
headers = {}
|
||||
for raw_line in header_block.split(b"\r\n")[1:]:
|
||||
if b":" in raw_line:
|
||||
k, v = raw_line.decode(errors="replace").split(":", 1)
|
||||
headers[k.strip()] = v.strip()
|
||||
|
||||
# Build full URL (browser sends just the path in CONNECT)
|
||||
if port == 443:
|
||||
url = f"https://{host}{path}"
|
||||
else:
|
||||
url = f"https://{host}:{port}{path}"
|
||||
|
||||
log.info("MITM → %s %s", method, url)
|
||||
|
||||
# ── CORS: extract relevant request headers ────────────────────
|
||||
origin = next(
|
||||
(v for k, v in headers.items() if k.lower() == "origin"), ""
|
||||
)
|
||||
acr_method = next(
|
||||
(v for k, v in headers.items()
|
||||
if k.lower() == "access-control-request-method"), ""
|
||||
)
|
||||
acr_headers = next(
|
||||
(v for k, v in headers.items()
|
||||
if k.lower() == "access-control-request-headers"), ""
|
||||
)
|
||||
|
||||
# CORS preflight — respond directly; UrlFetchApp doesn't
|
||||
# support OPTIONS so forwarding it would always fail.
|
||||
if method.upper() == "OPTIONS" and acr_method:
|
||||
log.debug("CORS preflight → %s (responding locally)", url[:60])
|
||||
writer.write(self._cors_preflight_response(origin, acr_method, acr_headers))
|
||||
await writer.drain()
|
||||
continue
|
||||
|
||||
# Check local cache first (GET only)
|
||||
response = None
|
||||
if method == "GET" and not body:
|
||||
response = self._cache.get(url)
|
||||
if response:
|
||||
log.debug("Cache HIT: %s", url[:60])
|
||||
|
||||
if response is None:
|
||||
# Relay through Apps Script
|
||||
try:
|
||||
response = await self._relay_smart(method, url, headers, body)
|
||||
except Exception as e:
|
||||
log.error("Relay error (%s): %s", url[:60], e)
|
||||
err_body = f"Relay error: {e}".encode()
|
||||
response = (
|
||||
b"HTTP/1.1 502 Bad Gateway\r\n"
|
||||
b"Content-Type: text/plain\r\n"
|
||||
b"Content-Length: " + str(len(err_body)).encode() + b"\r\n"
|
||||
b"\r\n" + err_body
|
||||
)
|
||||
|
||||
# Cache successful GET responses
|
||||
if method == "GET" and not body and response:
|
||||
ttl = ResponseCache.parse_ttl(response, url)
|
||||
if ttl > 0:
|
||||
self._cache.put(url, response, ttl)
|
||||
log.debug("Cached (%ds): %s", ttl, url[:60])
|
||||
|
||||
# Inject permissive CORS headers whenever the browser
|
||||
# sent an Origin (cross-origin XHR / fetch).
|
||||
if origin and response:
|
||||
response = self._inject_cors_headers(response, origin)
|
||||
|
||||
writer.write(response)
|
||||
await writer.drain()
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
except asyncio.IncompleteReadError:
|
||||
break
|
||||
except ConnectionError:
|
||||
break
|
||||
except Exception as e:
|
||||
log.error("MITM handler error (%s): %s", host, e)
|
||||
break
|
||||
|
||||
# ── CORS helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _cors_preflight_response(origin: str, acr_method: str, acr_headers: str) -> bytes:
|
||||
"""Return a 204 No Content response that satisfies a CORS preflight."""
|
||||
allow_origin = origin or "*"
|
||||
allow_methods = (
|
||||
f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS"
|
||||
if acr_method else
|
||||
"GET, POST, PUT, DELETE, PATCH, OPTIONS"
|
||||
)
|
||||
allow_headers = acr_headers or "*"
|
||||
return (
|
||||
"HTTP/1.1 204 No Content\r\n"
|
||||
f"Access-Control-Allow-Origin: {allow_origin}\r\n"
|
||||
f"Access-Control-Allow-Methods: {allow_methods}\r\n"
|
||||
f"Access-Control-Allow-Headers: {allow_headers}\r\n"
|
||||
"Access-Control-Allow-Credentials: true\r\n"
|
||||
"Access-Control-Max-Age: 86400\r\n"
|
||||
"Vary: Origin\r\n"
|
||||
"Content-Length: 0\r\n"
|
||||
"\r\n"
|
||||
).encode()
|
||||
|
||||
@staticmethod
|
||||
def _inject_cors_headers(response: bytes, origin: str) -> bytes:
|
||||
"""Inject CORS headers only if the upstream response lacks them.
|
||||
|
||||
We must NOT overwrite the origin server's CORS headers: sites like
|
||||
x.com return carefully-scoped Access-Control-Allow-Headers that list
|
||||
specific custom headers (e.g. x-csrf-token). Replacing them with
|
||||
wildcards together with Allow-Credentials: true makes browsers
|
||||
reject the response (per the Fetch spec, "*" is literal when
|
||||
credentials are included), which the site then blames on privacy
|
||||
extensions. So we only fill in what the server omitted.
|
||||
"""
|
||||
sep = b"\r\n\r\n"
|
||||
if sep not in response:
|
||||
return response
|
||||
header_section, body = response.split(sep, 1)
|
||||
lines = header_section.decode(errors="replace").split("\r\n")
|
||||
|
||||
existing = {ln.split(":", 1)[0].strip().lower()
|
||||
for ln in lines if ":" in ln}
|
||||
|
||||
# If the upstream already handled CORS, leave it completely alone.
|
||||
if "access-control-allow-origin" in existing:
|
||||
return response
|
||||
|
||||
# Otherwise inject a minimal, credential-safe set (no wildcards,
|
||||
# since wildcards combined with credentials are invalid).
|
||||
allow_origin = origin or "*"
|
||||
additions = [f"Access-Control-Allow-Origin: {allow_origin}"]
|
||||
if allow_origin != "*":
|
||||
additions.append("Access-Control-Allow-Credentials: true")
|
||||
additions.append("Vary: Origin")
|
||||
return ("\r\n".join(lines + additions) + "\r\n\r\n").encode() + body
|
||||
|
||||
async def _relay_smart(self, method, url, headers, body):
|
||||
"""Choose optimal relay strategy based on request type.
|
||||
|
||||
- GET requests for likely-large downloads use parallel-range.
|
||||
- All other requests (API calls, HTML, JSON, XHR) go through the
|
||||
single-request relay. This avoids injecting a synthetic Range
|
||||
header on normal traffic, which some origins honor by returning
|
||||
206 — breaking fetch()/XHR on sites like x.com or Cloudflare
|
||||
challenge pages.
|
||||
"""
|
||||
if method == "GET" and not body:
|
||||
# Respect client's own Range header verbatim.
|
||||
if headers:
|
||||
for k in headers:
|
||||
if k.lower() == "range":
|
||||
return await self.fronter.relay(
|
||||
method, url, headers, body
|
||||
)
|
||||
# Only probe with Range when the URL looks like a big file.
|
||||
if self._is_likely_download(url, headers):
|
||||
return await self.fronter.relay_parallel(
|
||||
method, url, headers, body
|
||||
)
|
||||
return await self.fronter.relay(method, url, headers, body)
|
||||
|
||||
def _is_likely_download(self, url: str, headers: dict) -> bool:
|
||||
"""Heuristic: is this URL likely a large file download?"""
|
||||
# Check file extension
|
||||
path = url.split("?")[0].lower()
|
||||
large_exts = {
|
||||
".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar",
|
||||
".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk",
|
||||
".iso", ".img",
|
||||
".mp4", ".mkv", ".avi", ".mov", ".webm",
|
||||
".mp3", ".flac", ".wav", ".aac",
|
||||
".pdf", ".doc", ".docx", ".ppt", ".pptx",
|
||||
".wasm",
|
||||
}
|
||||
for ext in large_exts:
|
||||
if path.endswith(ext):
|
||||
return True
|
||||
return False
|
||||
|
||||
# ── Plain HTTP forwarding ─────────────────────────────────────
|
||||
|
||||
async def _do_http(self, header_block: bytes, reader, writer):
|
||||
body = b""
|
||||
for raw_line in header_block.split(b"\r\n"):
|
||||
if raw_line.lower().startswith(b"content-length:"):
|
||||
length = int(raw_line.split(b":", 1)[1].strip())
|
||||
body = await reader.readexactly(length)
|
||||
break
|
||||
|
||||
first_line = header_block.split(b"\r\n")[0].decode(errors="replace")
|
||||
log.info("HTTP → %s", first_line)
|
||||
|
||||
if self.mode == "apps_script":
|
||||
# Parse request and relay through Apps Script
|
||||
parts = first_line.strip().split(" ", 2)
|
||||
method = parts[0] if parts else "GET"
|
||||
url = parts[1] if len(parts) > 1 else "/"
|
||||
|
||||
headers = {}
|
||||
for raw_line in header_block.split(b"\r\n")[1:]:
|
||||
if b":" in raw_line:
|
||||
k, v = raw_line.decode(errors="replace").split(":", 1)
|
||||
headers[k.strip()] = v.strip()
|
||||
|
||||
# ── CORS preflight over plain HTTP ────────────────────────────
|
||||
origin = next(
|
||||
(v for k, v in headers.items() if k.lower() == "origin"), ""
|
||||
)
|
||||
acr_method = next(
|
||||
(v for k, v in headers.items()
|
||||
if k.lower() == "access-control-request-method"), ""
|
||||
)
|
||||
acr_headers_val = next(
|
||||
(v for k, v in headers.items()
|
||||
if k.lower() == "access-control-request-headers"), ""
|
||||
)
|
||||
if method.upper() == "OPTIONS" and acr_method:
|
||||
log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60])
|
||||
writer.write(self._cors_preflight_response(origin, acr_method, acr_headers_val))
|
||||
await writer.drain()
|
||||
return
|
||||
|
||||
# Cache check for GET
|
||||
response = None
|
||||
if method == "GET" and not body:
|
||||
response = self._cache.get(url)
|
||||
if response:
|
||||
log.debug("Cache HIT (HTTP): %s", url[:60])
|
||||
|
||||
if response is None:
|
||||
response = await self._relay_smart(method, url, headers, body)
|
||||
# Cache successful GET
|
||||
if method == "GET" and not body and response:
|
||||
ttl = ResponseCache.parse_ttl(response, url)
|
||||
if ttl > 0:
|
||||
self._cache.put(url, response, ttl)
|
||||
|
||||
# Inject CORS headers for cross-origin requests
|
||||
if origin and response:
|
||||
response = self._inject_cors_headers(response, origin)
|
||||
elif self.mode in ("google_fronting", "custom_domain", "domain_fronting"):
|
||||
# Use WebSocket tunnel for ALL traffic (much faster than forward())
|
||||
response = await self._tunnel_http(header_block, body)
|
||||
else:
|
||||
response = await self.fronter.forward(header_block + body)
|
||||
|
||||
writer.write(response)
|
||||
await writer.drain()
|
||||
|
||||
async def _tunnel_http(self, header_block: bytes, body: bytes) -> bytes:
|
||||
"""Forward plain HTTP via a persistent WebSocket tunnel.
|
||||
|
||||
Instead of opening a new TLS+HTTP connection for each request
|
||||
(the old forward() path), this keeps a WebSocket tunnel open
|
||||
to the target host and pipes raw HTTP through it.
|
||||
Much faster for rapid-fire requests (e.g., Telegram API).
|
||||
"""
|
||||
import re as _re
|
||||
|
||||
# Parse target host:port from the raw HTTP request
|
||||
host = ""
|
||||
port = 80
|
||||
for line in header_block.split(b"\r\n")[1:]:
|
||||
if not line:
|
||||
break
|
||||
if line.lower().startswith(b"host:"):
|
||||
host_val = line.split(b":", 1)[1].strip().decode(errors="replace")
|
||||
if ":" in host_val:
|
||||
h, p = host_val.rsplit(":", 1)
|
||||
try:
|
||||
host, port = h, int(p)
|
||||
except ValueError:
|
||||
host = host_val
|
||||
else:
|
||||
host = host_val
|
||||
break
|
||||
|
||||
if not host:
|
||||
return b"HTTP/1.1 400 Bad Request\r\n\r\nNo Host header\r\n"
|
||||
|
||||
# Rewrite the request line: browser sends absolute URL
|
||||
# (e.g., "GET http://host/path HTTP/1.1") but the target
|
||||
# server expects a relative path ("GET /path HTTP/1.1")
|
||||
first_line = header_block.split(b"\r\n")[0]
|
||||
first_str = first_line.decode(errors="replace")
|
||||
parts = first_str.split(" ", 2)
|
||||
if len(parts) >= 2 and parts[1].startswith("http://"):
|
||||
from urllib.parse import urlparse
|
||||
parsed = urlparse(parts[1])
|
||||
rel_path = parsed.path or "/"
|
||||
if parsed.query:
|
||||
rel_path += "?" + parsed.query
|
||||
new_first = f"{parts[0]} {rel_path}"
|
||||
if len(parts) == 3:
|
||||
new_first += f" {parts[2]}"
|
||||
header_block = new_first.encode() + b"\r\n" + b"\r\n".join(header_block.split(b"\r\n")[1:])
|
||||
|
||||
raw_request = header_block + body
|
||||
|
||||
# Send through tunnel
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self.fronter.forward(raw_request), timeout=30
|
||||
)
|
||||
except Exception as e:
|
||||
log.error("Tunnel HTTP failed (%s:%d): %s", host, port, e)
|
||||
return b"HTTP/1.1 502 Bad Gateway\r\n\r\nTunnel forward failed\r\n"
|
||||
-76
@@ -1,76 +0,0 @@
|
||||
"""
|
||||
Minimal WebSocket frame encoder / decoder (RFC 6455).
|
||||
|
||||
Only handles binary (opcode 0x02) and close (opcode 0x08) frames.
|
||||
Client-to-server frames are always masked as required by the spec.
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
|
||||
|
||||
def ws_encode(data: bytes, opcode: int = 0x02) -> bytes:
|
||||
"""Encode *data* into a masked binary WebSocket frame."""
|
||||
head = bytearray([0x80 | opcode]) # FIN + opcode
|
||||
|
||||
length = len(data)
|
||||
if length < 126:
|
||||
head.append(0x80 | length)
|
||||
elif length < 0x10000:
|
||||
head.append(0x80 | 126)
|
||||
head += struct.pack("!H", length)
|
||||
else:
|
||||
head.append(0x80 | 127)
|
||||
head += struct.pack("!Q", length)
|
||||
|
||||
mask = os.urandom(4)
|
||||
head += mask
|
||||
|
||||
masked = bytearray(data)
|
||||
for i in range(len(masked)):
|
||||
masked[i] ^= mask[i & 3]
|
||||
|
||||
return bytes(head) + bytes(masked)
|
||||
|
||||
|
||||
def ws_decode(buf: bytes):
|
||||
"""Try to decode one frame from *buf*.
|
||||
|
||||
Returns ``(opcode, payload, consumed_bytes)`` or ``None`` if the
|
||||
buffer does not yet contain a complete frame.
|
||||
"""
|
||||
if len(buf) < 2:
|
||||
return None
|
||||
|
||||
opcode = buf[0] & 0x0F
|
||||
is_masked = buf[1] & 0x80
|
||||
length = buf[1] & 0x7F
|
||||
pos = 2
|
||||
|
||||
if length == 126:
|
||||
if len(buf) < 4:
|
||||
return None
|
||||
length = struct.unpack("!H", buf[2:4])[0]
|
||||
pos = 4
|
||||
elif length == 127:
|
||||
if len(buf) < 10:
|
||||
return None
|
||||
length = struct.unpack("!Q", buf[2:10])[0]
|
||||
pos = 10
|
||||
|
||||
mask = None
|
||||
if is_masked:
|
||||
if len(buf) < pos + 4:
|
||||
return None
|
||||
mask = buf[pos : pos + 4]
|
||||
pos += 4
|
||||
|
||||
if len(buf) < pos + length:
|
||||
return None
|
||||
|
||||
payload = bytearray(buf[pos : pos + length])
|
||||
if mask:
|
||||
for i in range(len(payload)):
|
||||
payload[i] ^= mask[i & 3]
|
||||
|
||||
return opcode, bytes(payload), pos + length
|
||||
@@ -6,7 +6,7 @@ Also attempts to install into Firefox's NSS certificate store when found.
|
||||
|
||||
Usage:
|
||||
from cert_installer import install_ca, is_ca_trusted
|
||||
install_ca("/path/to/ca.crt", cert_name="MHR_CFW")
|
||||
install_ca("/path/to/ca.crt", cert_name="mhr-cfw")
|
||||
"""
|
||||
|
||||
import glob
|
||||
@@ -18,7 +18,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
log = logging.getLogger("CertInstaller")
|
||||
log = logging.getLogger("Cert")
|
||||
|
||||
CERT_NAME = "mhr-cfw"
|
||||
|
||||
@@ -250,28 +250,68 @@ def _install_linux(cert_path: str, cert_name: str) -> bool:
|
||||
return installed
|
||||
|
||||
|
||||
def _is_trusted_linux(cert_path: str) -> bool:
|
||||
"""Check if our cert thumbprint is in the system's OpenSSL trust bundle."""
|
||||
thumbprint = _cert_thumbprint(cert_path)
|
||||
if not thumbprint:
|
||||
def _is_trusted_linux(cert_path: str, cert_name: str = CERT_NAME) -> bool:
|
||||
"""Check whether the cert appears in common Linux trust stores."""
|
||||
try:
|
||||
from cryptography import x509 as _x509
|
||||
from cryptography.hazmat.primitives import hashes as _hashes
|
||||
except Exception:
|
||||
return False
|
||||
bundle_paths = [
|
||||
"/etc/ssl/certs/ca-certificates.crt", # Debian/Ubuntu
|
||||
"/etc/pki/tls/certs/ca-bundle.crt", # RHEL/Fedora
|
||||
"/etc/ssl/ca-bundle.pem", # OpenSUSE
|
||||
"/etc/ca-certificates/ca-certificates.crt",
|
||||
]
|
||||
# A fast heuristic: check if our CA cert file was copied to known dirs
|
||||
|
||||
try:
|
||||
with open(cert_path, "rb") as f:
|
||||
target_cert = _x509.load_pem_x509_certificate(f.read())
|
||||
target_fp = target_cert.fingerprint(_hashes.SHA1())
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# First check the common anchor locations used by the installer.
|
||||
expected_name = f"{cert_name.replace(' ', '_')}.crt"
|
||||
anchor_dirs = [
|
||||
"/usr/local/share/ca-certificates",
|
||||
"/etc/pki/ca-trust/source/anchors",
|
||||
"/etc/ca-certificates/trust-source/anchors",
|
||||
]
|
||||
for d in anchor_dirs:
|
||||
if os.path.isdir(d):
|
||||
for f in os.listdir(d):
|
||||
if "DomainFront" in f or "domainfront" in f.lower():
|
||||
try:
|
||||
if not os.path.isdir(d):
|
||||
continue
|
||||
if expected_name in os.listdir(d):
|
||||
return True
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Fall back to scanning the system bundle files directly.
|
||||
bundle_paths = [
|
||||
"/etc/ssl/certs/ca-certificates.crt", # Debian/Ubuntu
|
||||
"/etc/pki/tls/certs/ca-bundle.crt", # RHEL/Fedora
|
||||
"/etc/ssl/ca-bundle.pem", # OpenSUSE
|
||||
"/etc/ca-certificates/ca-certificates.crt",
|
||||
]
|
||||
|
||||
begin = b"-----BEGIN CERTIFICATE-----"
|
||||
end = b"-----END CERTIFICATE-----"
|
||||
for bundle in bundle_paths:
|
||||
try:
|
||||
with open(bundle, "rb") as f:
|
||||
data = f.read()
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
for chunk in data.split(begin):
|
||||
if end not in chunk:
|
||||
continue
|
||||
pem = begin + chunk.split(end, 1)[0] + end + b"\n"
|
||||
try:
|
||||
cert = _x509.load_pem_x509_certificate(pem)
|
||||
except Exception:
|
||||
continue
|
||||
try:
|
||||
if cert.fingerprint(_hashes.SHA1()) == target_fp:
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
|
||||
@@ -279,43 +319,225 @@ def _is_trusted_linux(cert_path: str) -> bool:
|
||||
# Firefox NSS (cross-platform)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
# def _install_firefox(cert_path: str, cert_name: str):
|
||||
# """Install into all detected Firefox profile NSS databases."""
|
||||
# if not _has_cmd("certutil"):
|
||||
# log.debug("NSS certutil not found — skipping Firefox install.")
|
||||
# return
|
||||
def _install_firefox(cert_path: str, cert_name: str):
|
||||
"""Install into all detected Firefox profile NSS databases."""
|
||||
if not _has_cmd("certutil"):
|
||||
log.debug("NSS certutil not found — skipping Firefox install.")
|
||||
return
|
||||
|
||||
# profile_dirs: list[str] = []
|
||||
# system = platform.system()
|
||||
profile_dirs: list[str] = []
|
||||
system = platform.system()
|
||||
|
||||
# if system == "Windows":
|
||||
# appdata = os.environ.get("APPDATA", "")
|
||||
# profile_dirs += glob.glob(os.path.join(appdata, r"Mozilla\Firefox\Profiles\*"))
|
||||
# elif system == "Darwin":
|
||||
# profile_dirs += glob.glob(os.path.expanduser("~/Library/Application Support/Firefox/Profiles/*"))
|
||||
# else:
|
||||
# profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.default*"))
|
||||
# profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.release*"))
|
||||
if system == "Windows":
|
||||
appdata = os.environ.get("APPDATA", "")
|
||||
profile_dirs += glob.glob(os.path.join(appdata, r"Mozilla\Firefox\Profiles\*"))
|
||||
elif system == "Darwin":
|
||||
profile_dirs += glob.glob(os.path.expanduser("~/Library/Application Support/Firefox/Profiles/*"))
|
||||
else:
|
||||
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.default*"))
|
||||
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.release*"))
|
||||
|
||||
# if not profile_dirs:
|
||||
# log.debug("No Firefox profiles found.")
|
||||
# return
|
||||
if not profile_dirs:
|
||||
log.debug("No Firefox profiles found.")
|
||||
return
|
||||
|
||||
# for profile in profile_dirs:
|
||||
# db = f"sql:{profile}" if os.path.exists(os.path.join(profile, "cert9.db")) else f"dbm:{profile}"
|
||||
# try:
|
||||
# # Remove old entry first (ignore errors)
|
||||
# _run(["certutil", "-D", "-n", cert_name, "-d", db], check=False)
|
||||
# _run([
|
||||
# "certutil", "-A",
|
||||
# "-n", cert_name,
|
||||
# "-t", "CT,,",
|
||||
# "-i", cert_path,
|
||||
# "-d", db,
|
||||
# ])
|
||||
# log.info("Installed in Firefox profile: %s", os.path.basename(profile))
|
||||
# except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
# log.warning("Firefox profile %s: %s", os.path.basename(profile), exc)
|
||||
for profile in profile_dirs:
|
||||
db = f"sql:{profile}" if os.path.exists(os.path.join(profile, "cert9.db")) else f"dbm:{profile}"
|
||||
try:
|
||||
# Remove old entry first (ignore errors)
|
||||
_run(["certutil", "-D", "-n", cert_name, "-d", db], check=False)
|
||||
_run([
|
||||
"certutil", "-A",
|
||||
"-n", cert_name,
|
||||
"-t", "CT,,",
|
||||
"-i", cert_path,
|
||||
"-d", db,
|
||||
])
|
||||
log.info("Installed in Firefox profile: %s", os.path.basename(profile))
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.warning("Firefox profile %s: %s", os.path.basename(profile), exc)
|
||||
|
||||
|
||||
def _uninstall_firefox(cert_name: str):
|
||||
"""Remove certificate from all detected Firefox profile NSS databases."""
|
||||
if not _has_cmd("certutil"):
|
||||
log.debug("NSS certutil not found — skipping Firefox uninstall.")
|
||||
return
|
||||
|
||||
profile_dirs: list[str] = []
|
||||
system = platform.system()
|
||||
|
||||
if system == "Windows":
|
||||
appdata = os.environ.get("APPDATA", "")
|
||||
profile_dirs += glob.glob(os.path.join(appdata, r"Mozilla\Firefox\Profiles\*"))
|
||||
elif system == "Darwin":
|
||||
profile_dirs += glob.glob(os.path.expanduser("~/Library/Application Support/Firefox/Profiles/*"))
|
||||
else:
|
||||
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.default*"))
|
||||
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.release*"))
|
||||
|
||||
if not profile_dirs:
|
||||
log.debug("No Firefox profiles found.")
|
||||
return
|
||||
|
||||
for profile in profile_dirs:
|
||||
db = f"sql:{profile}" if os.path.exists(os.path.join(profile, "cert9.db")) else f"dbm:{profile}"
|
||||
try:
|
||||
result = _run(["certutil", "-D", "-n", cert_name, "-d", db], check=False)
|
||||
if result.returncode == 0:
|
||||
log.info("Removed from Firefox profile: %s", os.path.basename(profile))
|
||||
else:
|
||||
log.debug("Firefox profile %s: certificate not present", os.path.basename(profile))
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.debug("Firefox profile %s: %s", os.path.basename(profile), exc)
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Uninstall functions
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
def _uninstall_windows(cert_path: str, cert_name: str) -> bool:
|
||||
"""Remove certificate from the Windows Trusted Root store."""
|
||||
thumbprint = _cert_thumbprint(cert_path)
|
||||
|
||||
# Try per-user store first (no admin required)
|
||||
try:
|
||||
target = thumbprint if thumbprint else cert_name
|
||||
_run(["certutil", "-delstore", "-user", "Root", target])
|
||||
log.info("Certificate removed from Windows user Trusted Root store.")
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.warning("certutil user store removal failed: %s", exc)
|
||||
|
||||
# Try system store (requires admin)
|
||||
try:
|
||||
target = thumbprint if thumbprint else cert_name
|
||||
_run(["certutil", "-delstore", "Root", target])
|
||||
log.info("Certificate removed from Windows system Trusted Root store.")
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.warning("certutil system store removal failed: %s", exc)
|
||||
|
||||
# Fallback: use PowerShell
|
||||
try:
|
||||
if thumbprint:
|
||||
ps_cmd = (
|
||||
"Get-ChildItem Cert:\\CurrentUser\\Root | "
|
||||
f"Where-Object {{ $_.Thumbprint -eq '{thumbprint}' }} | "
|
||||
"Remove-Item -Force -ErrorAction SilentlyContinue"
|
||||
)
|
||||
else:
|
||||
ps_cmd = (
|
||||
"Get-ChildItem Cert:\\CurrentUser\\Root | "
|
||||
f"Where-Object {{ $_.Subject -like '*CN={cert_name}*' -or $_.FriendlyName -eq '{cert_name}' }} | "
|
||||
"Remove-Item -Force -ErrorAction SilentlyContinue"
|
||||
)
|
||||
_run(["powershell", "-NoProfile", "-Command", ps_cmd])
|
||||
log.info("Certificate removal via PowerShell completed.")
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.error("PowerShell removal failed: %s", exc)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _uninstall_macos(cert_name: str) -> bool:
|
||||
"""Remove certificate from the macOS keychains."""
|
||||
login_keychain = os.path.expanduser("~/Library/Keychains/login.keychain-db")
|
||||
if not os.path.exists(login_keychain):
|
||||
login_keychain = os.path.expanduser("~/Library/Keychains/login.keychain")
|
||||
|
||||
try:
|
||||
_run([
|
||||
"security", "delete-certificate",
|
||||
"-c", cert_name,
|
||||
login_keychain,
|
||||
])
|
||||
log.info("Certificate removed from macOS login keychain.")
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.warning("login keychain removal failed: %s", exc)
|
||||
|
||||
# Try system keychain (needs sudo)
|
||||
try:
|
||||
_run([
|
||||
"sudo", "security", "delete-certificate",
|
||||
"-c", cert_name,
|
||||
"/Library/Keychains/System.keychain",
|
||||
])
|
||||
log.info("Certificate removed from macOS system keychain.")
|
||||
return True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
|
||||
log.debug("System keychain removal failed: %s", exc)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _uninstall_linux(cert_path: str, cert_name: str) -> bool:
|
||||
"""Remove certificate from Linux trust stores."""
|
||||
distro = _detect_linux_distro()
|
||||
log.info("Detected Linux distro family: %s", distro)
|
||||
|
||||
removed = False
|
||||
|
||||
if distro == "debian":
|
||||
dest_file = f"/usr/local/share/ca-certificates/{cert_name.replace(' ', '_')}.crt"
|
||||
try:
|
||||
if os.path.exists(dest_file):
|
||||
os.remove(dest_file)
|
||||
_run(["update-ca-certificates"])
|
||||
log.info("Certificate removed via update-ca-certificates.")
|
||||
removed = True
|
||||
except (OSError, subprocess.CalledProcessError) as exc:
|
||||
log.warning("Debian removal failed (needs sudo?): %s", exc)
|
||||
try:
|
||||
_run(["sudo", "rm", "-f", dest_file])
|
||||
_run(["sudo", "update-ca-certificates"])
|
||||
log.info("Certificate removed via sudo update-ca-certificates.")
|
||||
removed = True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc2:
|
||||
log.warning("sudo Debian removal failed: %s", exc2)
|
||||
|
||||
elif distro == "rhel":
|
||||
dest_file = f"/etc/pki/ca-trust/source/anchors/{cert_name.replace(' ', '_')}.crt"
|
||||
try:
|
||||
if os.path.exists(dest_file):
|
||||
os.remove(dest_file)
|
||||
_run(["update-ca-trust", "extract"])
|
||||
log.info("Certificate removed via update-ca-trust.")
|
||||
removed = True
|
||||
except (OSError, subprocess.CalledProcessError) as exc:
|
||||
log.warning("RHEL removal failed (needs sudo?): %s", exc)
|
||||
try:
|
||||
_run(["sudo", "rm", "-f", dest_file])
|
||||
_run(["sudo", "update-ca-trust", "extract"])
|
||||
log.info("Certificate removed via sudo update-ca-trust.")
|
||||
removed = True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc2:
|
||||
log.warning("sudo RHEL removal failed: %s", exc2)
|
||||
|
||||
elif distro == "arch":
|
||||
dest_file = f"/etc/ca-certificates/trust-source/anchors/{cert_name.replace(' ', '_')}.crt"
|
||||
try:
|
||||
if os.path.exists(dest_file):
|
||||
os.remove(dest_file)
|
||||
_run(["trust", "extract-compat"])
|
||||
log.info("Certificate removed via trust extract-compat.")
|
||||
removed = True
|
||||
except (OSError, subprocess.CalledProcessError) as exc:
|
||||
log.warning("Arch removal failed (needs sudo?): %s", exc)
|
||||
try:
|
||||
_run(["sudo", "rm", "-f", dest_file])
|
||||
_run(["sudo", "trust", "extract-compat"])
|
||||
log.info("Certificate removed via sudo trust extract-compat.")
|
||||
removed = True
|
||||
except (subprocess.CalledProcessError, FileNotFoundError) as exc2:
|
||||
log.warning("sudo Arch removal failed: %s", exc2)
|
||||
|
||||
else:
|
||||
log.warning("Unknown Linux distro. Manually remove %s from trusted CAs.", cert_name)
|
||||
|
||||
return removed
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
@@ -330,7 +552,7 @@ def is_ca_trusted(cert_path: str) -> bool:
|
||||
return _is_trusted_windows(cert_path)
|
||||
if system == "Darwin":
|
||||
return _is_trusted_macos(CERT_NAME)
|
||||
return _is_trusted_linux(cert_path)
|
||||
return _is_trusted_linux(cert_path, CERT_NAME)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -360,6 +582,32 @@ def install_ca(cert_path: str, cert_name: str = CERT_NAME) -> bool:
|
||||
return False
|
||||
|
||||
# Best-effort Firefox install on all platforms
|
||||
# _install_firefox(cert_path, cert_name)
|
||||
_install_firefox(cert_path, cert_name)
|
||||
|
||||
return ok
|
||||
|
||||
|
||||
def uninstall_ca(cert_path: str, cert_name: str = CERT_NAME) -> bool:
|
||||
"""
|
||||
Remove *cert_name* from the system's trusted root CAs on the current platform.
|
||||
Also attempts Firefox NSS removal.
|
||||
|
||||
Returns True if the system store removal succeeded.
|
||||
"""
|
||||
system = platform.system()
|
||||
log.info("Removing CA certificate from %s…", system)
|
||||
|
||||
if system == "Windows":
|
||||
ok = _uninstall_windows(cert_path, cert_name)
|
||||
elif system == "Darwin":
|
||||
ok = _uninstall_macos(cert_name)
|
||||
elif system == "Linux":
|
||||
ok = _uninstall_linux(cert_path, cert_name)
|
||||
else:
|
||||
log.error("Unsupported platform: %s", system)
|
||||
return False
|
||||
|
||||
# Best-effort Firefox uninstall on all platforms
|
||||
_uninstall_firefox(cert_name)
|
||||
|
||||
return ok
|
||||
Reference in New Issue
Block a user