add src/cert_installer.py

This commit is contained in:
anthroposcene
2026-04-29 02:55:11 -07:00
parent 06d35504b1
commit aa220c720d
6 changed files with 299 additions and 2706 deletions
File diff suppressed because it is too large Load Diff
-419
View File
@@ -1,419 +0,0 @@
"""
HTTP/2 multiplexed transport for domain-fronted connections.
One TLS connection → many concurrent HTTP/2 streams → massive throughput.
Eliminates per-request TLS handshake overhead entirely.
Instead of a pool of 30 HTTP/1.1 connections (each handling 1 request),
this uses a SINGLE HTTP/2 connection handling 100+ concurrent requests.
Performance comparison:
HTTP/1.1 pool: 30 connections × 1 request = 30 concurrent requests max
HTTP/2 mux: 1 connection × 100 streams = 100 concurrent requests
Requires: pip install h2
"""
import asyncio
import gzip
import logging
import socket
import ssl
from urllib.parse import urlparse
log = logging.getLogger("H2")
try:
import h2.connection
import h2.config
import h2.events
import h2.settings
H2_AVAILABLE = True
except ImportError:
H2_AVAILABLE = False
class _StreamState:
"""State for a single in-flight HTTP/2 stream."""
__slots__ = ("status", "headers", "data", "done", "error")
def __init__(self):
self.status = 0
self.headers: dict[str, str] = {}
self.data = bytearray()
self.done = asyncio.Event()
self.error: str | None = None
class H2Transport:
"""
Persistent HTTP/2 connection with automatic stream multiplexing.
All relay requests share ONE TLS connection. Each request becomes
an independent HTTP/2 stream, running fully concurrently.
Features:
- Auto-connect on first use
- Auto-reconnect on connection loss
- Redirect following (as new streams, same connection)
- Gzip decompression
- Configurable max concurrency
"""
def __init__(self, connect_host: str, sni_host: str,
verify_ssl: bool = True):
self.connect_host = connect_host
self.sni_host = sni_host
self.verify_ssl = verify_ssl
self._reader: asyncio.StreamReader | None = None
self._writer: asyncio.StreamWriter | None = None
self._h2: "h2.connection.H2Connection | None" = None
self._connected = False
self._write_lock = asyncio.Lock()
self._connect_lock = asyncio.Lock()
self._read_task: asyncio.Task | None = None
# Per-stream tracking
self._streams: dict[int, _StreamState] = {}
# Stats
self.total_requests = 0
self.total_streams = 0
# ── Connection lifecycle ──────────────────────────────────────
@property
def is_connected(self) -> bool:
return self._connected
async def ensure_connected(self):
"""Connect if not already connected."""
if self._connected:
return
async with self._connect_lock:
if self._connected:
return
await self._do_connect()
async def _do_connect(self):
"""Establish the HTTP/2 connection with optimized socket settings."""
ctx = ssl.create_default_context()
# Advertise both h2 and http/1.1 — some DPI blocks h2-only ALPN
ctx.set_alpn_protocols(["h2", "http/1.1"])
if not self.verify_ssl:
ctx.check_hostname = False
ctx.verify_mode = ssl.CERT_NONE
# Create raw TCP socket with TCP_NODELAY BEFORE TLS handshake.
# Nagle's algorithm can delay small writes (H2 frames) by up to 200ms
# waiting to coalesce — TCP_NODELAY forces immediate send.
raw = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
raw.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
raw.setblocking(False)
try:
await asyncio.wait_for(
asyncio.get_event_loop().sock_connect(
raw, (self.connect_host, 443)
),
timeout=15,
)
self._reader, self._writer = await asyncio.wait_for(
asyncio.open_connection(
ssl=ctx,
server_hostname=self.sni_host,
sock=raw,
),
timeout=15,
)
except Exception:
raw.close()
raise
# Verify we actually got HTTP/2
ssl_obj = self._writer.get_extra_info("ssl_object")
negotiated = ssl_obj.selected_alpn_protocol() if ssl_obj else None
if negotiated != "h2":
self._writer.close()
raise RuntimeError(
f"H2 ALPN negotiation failed (got {negotiated!r})"
)
config = h2.config.H2Configuration(
client_side=True,
header_encoding="utf-8",
)
self._h2 = h2.connection.H2Connection(config=config)
self._h2.initiate_connection()
# Connection-level flow control: ~16MB window
self._h2.increment_flow_control_window(2 ** 24 - 65535)
# Per-stream settings: 1MB initial window, disable server push
self._h2.update_settings({
h2.settings.SettingCodes.INITIAL_WINDOW_SIZE: 1 * 1024 * 1024,
h2.settings.SettingCodes.ENABLE_PUSH: 0,
})
await self._flush()
self._connected = True
self._read_task = asyncio.create_task(self._reader_loop())
log.info("H2 connected → %s (SNI=%s, TCP_NODELAY=on)",
self.connect_host, self.sni_host)
async def reconnect(self):
"""Close current connection and re-establish."""
await self._close_internal()
await self._do_connect()
async def _close_internal(self):
self._connected = False
if self._read_task:
self._read_task.cancel()
self._read_task = None
if self._writer:
try:
self._writer.close()
except Exception:
pass
self._writer = None
# Wake all pending streams so they can raise
for state in self._streams.values():
state.error = "Connection closed"
state.done.set()
self._streams.clear()
# ── Public API ────────────────────────────────────────────────
async def request(self, method: str, path: str, host: str,
headers: dict | None = None,
body: bytes | None = None,
timeout: float = 25,
follow_redirects: int = 5) -> tuple[int, dict, bytes]:
"""
Send an HTTP/2 request and return (status, headers, body).
Thread-safe: many concurrent calls each get their own stream.
Redirects are followed as new streams on the same connection.
"""
await self.ensure_connected()
self.total_requests += 1
for _ in range(follow_redirects + 1):
status, resp_headers, resp_body = await self._single_request(
method, path, host, headers, body, timeout,
)
if status not in (301, 302, 303, 307, 308):
return status, resp_headers, resp_body
location = resp_headers.get("location", "")
if not location:
return status, resp_headers, resp_body
parsed = urlparse(location)
path = parsed.path + ("?" + parsed.query if parsed.query else "")
host = parsed.netloc or host
method = "GET"
body = None
headers = None # Drop request headers on redirect
return status, resp_headers, resp_body
# ── Stream handling ───────────────────────────────────────────
async def _single_request(self, method, path, host, headers, body,
timeout) -> tuple[int, dict, bytes]:
"""Send one HTTP/2 request on a new stream, wait for response."""
if not self._connected:
await self.ensure_connected()
stream_id = None
async with self._write_lock:
try:
stream_id = self._h2.get_next_available_stream_id()
except Exception:
# Connection is stale — reconnect
await self.reconnect()
stream_id = self._h2.get_next_available_stream_id()
h2_headers = [
(":method", method),
(":path", path),
(":authority", host),
(":scheme", "https"),
("accept-encoding", "gzip"),
]
if headers:
for k, v in headers.items():
h2_headers.append((k.lower(), str(v)))
end_stream = not body
self._h2.send_headers(stream_id, h2_headers, end_stream=end_stream)
if body:
# Send body (may need chunking for flow control)
self._send_body(stream_id, body)
state = _StreamState()
self._streams[stream_id] = state
self.total_streams += 1
await self._flush()
# Wait for complete response
try:
await asyncio.wait_for(state.done.wait(), timeout=timeout)
except asyncio.TimeoutError:
self._streams.pop(stream_id, None)
raise TimeoutError(
f"H2 stream {stream_id} timed out ({timeout}s)"
)
self._streams.pop(stream_id, None)
if state.error:
raise ConnectionError(f"H2 stream error: {state.error}")
# Auto-decompress gzip
resp_body = bytes(state.data)
if state.headers.get("content-encoding", "").lower() == "gzip":
try:
resp_body = gzip.decompress(resp_body)
except Exception:
pass
return state.status, state.headers, resp_body
def _send_body(self, stream_id: int, body: bytes):
"""Send request body, respecting H2 flow control window."""
# For small bodies (typical JSON payloads), send in one shot
while body:
max_size = self._h2.local_settings.max_frame_size
window = self._h2.local_flow_control_window(stream_id)
send_size = min(len(body), max_size, window)
if send_size <= 0:
# Flow control full — let the reader loop process
# window updates before we continue
break
end = send_size >= len(body)
self._h2.send_data(stream_id, body[:send_size], end_stream=end)
body = body[send_size:]
# ── Background reader ─────────────────────────────────────────
async def _reader_loop(self):
"""Background: read H2 frames, dispatch events to waiting streams."""
try:
while self._connected:
data = await self._reader.read(65536)
if not data:
log.warning("H2 remote closed connection")
break
try:
events = self._h2.receive_data(data)
except Exception as e:
log.error("H2 protocol error: %s", e)
break
for event in events:
self._dispatch(event)
# Send pending data (acks, window updates, ping responses)
async with self._write_lock:
await self._flush()
except asyncio.CancelledError:
pass
except Exception as e:
log.error("H2 reader error: %s", e)
finally:
self._connected = False
for state in self._streams.values():
if not state.done.is_set():
state.error = "Connection lost"
state.done.set()
log.info("H2 reader loop ended")
def _dispatch(self, event):
"""Route a single h2 event to its stream."""
if isinstance(event, h2.events.ResponseReceived):
state = self._streams.get(event.stream_id)
if state:
for name, value in event.headers:
n = name if isinstance(name, str) else name.decode()
v = value if isinstance(value, str) else value.decode()
if n == ":status":
state.status = int(v)
else:
state.headers[n] = v
elif isinstance(event, h2.events.DataReceived):
state = self._streams.get(event.stream_id)
if state:
state.data.extend(event.data)
# Always acknowledge received data for flow control
self._h2.acknowledge_received_data(
event.flow_controlled_length, event.stream_id
)
elif isinstance(event, h2.events.StreamEnded):
state = self._streams.get(event.stream_id)
if state:
state.done.set()
elif isinstance(event, h2.events.StreamReset):
state = self._streams.get(event.stream_id)
if state:
state.error = f"Stream reset (code={event.error_code})"
state.done.set()
elif isinstance(event, h2.events.WindowUpdated):
pass # h2 library handles window bookkeeping
elif isinstance(event, h2.events.SettingsAcknowledged):
pass
elif isinstance(event, h2.events.PingReceived):
pass # h2 library auto-responds
elif isinstance(event, h2.events.PingAckReceived):
pass # keepalive confirmed
# ── Internal ──────────────────────────────────────────────────
async def _flush(self):
"""Write pending H2 frame data to the socket."""
data = self._h2.data_to_send()
if data and self._writer:
self._writer.write(data)
await self._writer.drain()
async def close(self):
"""Gracefully close the HTTP/2 connection."""
if self._h2 and self._connected:
try:
self._h2.close_connection()
async with self._write_lock:
await self._flush()
except Exception:
pass
await self._close_internal()
async def ping(self):
"""Send an H2 PING frame to keep the connection alive."""
if not self._connected or not self._h2:
return
try:
async with self._write_lock:
if not self._connected:
return
self._h2.ping(b"\x00" * 8)
await self._flush()
except Exception as e:
log.debug("H2 PING failed: %s", e)
-153
View File
@@ -1,153 +0,0 @@
"""
MITM certificate manager for HTTPS interception.
Generates a CA certificate (once, stored as files) and per-domain
certificates (on the fly, cached in memory) so the local proxy can
decrypt HTTPS traffic and relay it through Apps Script.
The user must install ca/ca.crt in their browser's trusted CAs once.
Requires: pip install cryptography
"""
import datetime
import logging
import os
import ssl
import tempfile
from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID
log = logging.getLogger("MITM")
CA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../cert")
CA_KEY_FILE = os.path.join(CA_DIR, "ca.key")
CA_CERT_FILE = os.path.join(CA_DIR, "ca.crt")
class MITMCertManager:
def __init__(self):
self._ca_key = None
self._ca_cert = None
self._ctx_cache: dict[str, ssl.SSLContext] = {}
self._cert_dir = tempfile.mkdtemp(prefix="domainfront_certs_")
self._ensure_ca()
def _ensure_ca(self):
if os.path.exists(CA_KEY_FILE) and os.path.exists(CA_CERT_FILE):
with open(CA_KEY_FILE, "rb") as f:
self._ca_key = serialization.load_pem_private_key(
f.read(), password=None
)
with open(CA_CERT_FILE, "rb") as f:
self._ca_cert = x509.load_pem_x509_certificate(f.read())
log.info("Loaded CA from %s", CA_DIR)
else:
self._create_ca()
def _create_ca(self):
os.makedirs(CA_DIR, exist_ok=True)
self._ca_key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)
subject = issuer = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, "MHR_CFW"),
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "MHR_CFW"),
])
now = datetime.datetime.now(datetime.timezone.utc)
self._ca_cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(issuer)
.public_key(self._ca_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=3650))
.add_extension(
x509.BasicConstraints(ca=True, path_length=0), critical=True
)
.add_extension(
x509.KeyUsage(
digital_signature=True,
key_cert_sign=True,
crl_sign=True,
content_commitment=False,
key_encipherment=False,
data_encipherment=False,
key_agreement=False,
encipher_only=False,
decipher_only=False,
),
critical=True,
)
.sign(self._ca_key, hashes.SHA256())
)
with open(CA_KEY_FILE, "wb") as f:
f.write(
self._ca_key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.TraditionalOpenSSL,
serialization.NoEncryption(),
)
)
with open(CA_CERT_FILE, "wb") as f:
f.write(self._ca_cert.public_bytes(serialization.Encoding.PEM))
log.warning("Generated new CA certificate: %s", CA_CERT_FILE)
log.warning(">>> Install this file in your browser's Trusted Root CAs! <<<")
def get_server_context(self, domain: str) -> ssl.SSLContext:
if domain not in self._ctx_cache:
key_pem, cert_pem = self._generate_domain_cert(domain)
cert_file = os.path.join(self._cert_dir, f"{domain}.crt")
key_file = os.path.join(self._cert_dir, f"{domain}.key")
ca_pem = self._ca_cert.public_bytes(serialization.Encoding.PEM)
with open(cert_file, "wb") as f:
f.write(cert_pem + ca_pem)
with open(key_file, "wb") as f:
f.write(key_pem)
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
ctx.set_alpn_protocols(["http/1.1"])
ctx.load_cert_chain(cert_file, key_file)
self._ctx_cache[domain] = ctx
return self._ctx_cache[domain]
def _generate_domain_cert(self, domain: str):
key = rsa.generate_private_key(
public_exponent=65537, key_size=2048
)
subject = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, domain),
])
now = datetime.datetime.now(datetime.timezone.utc)
cert = (
x509.CertificateBuilder()
.subject_name(subject)
.issuer_name(self._ca_cert.subject)
.public_key(key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(now)
.not_valid_after(now + datetime.timedelta(days=365))
.add_extension(
x509.SubjectAlternativeName([x509.DNSName(domain)]),
critical=False,
)
.sign(self._ca_key, hashes.SHA256())
)
key_pem = key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.TraditionalOpenSSL,
serialization.NoEncryption(),
)
cert_pem = cert.public_bytes(serialization.Encoding.PEM)
return key_pem, cert_pem
-777
View File
@@ -1,777 +0,0 @@
"""
Local HTTP proxy server.
Intercepts the user's browser traffic and forwards everything through
a domain-fronted connection to a CDN worker or Apps Script relay.
Supports:
- CONNECT method → WebSocket tunnel (modes 1-3) or MITM relay (apps_script)
- GET / POST etc. → HTTP forwarding (modes 1-3) or JSON relay (apps_script)
"""
import asyncio
import logging
import re
import ssl
import time
from core.domain_fronter import DomainFronter
log = logging.getLogger("Proxy")
class ResponseCache:
"""Simple LRU response cache — avoids repeated relay calls."""
def __init__(self, max_mb: int = 50):
self._store: dict[str, tuple[bytes, float]] = {}
self._size = 0
self._max = max_mb * 1024 * 1024
self.hits = 0
self.misses = 0
def get(self, url: str) -> bytes | None:
entry = self._store.get(url)
if not entry:
self.misses += 1
return None
raw, expires = entry
if time.time() > expires:
self._size -= len(raw)
del self._store[url]
self.misses += 1
return None
self.hits += 1
return raw
def put(self, url: str, raw_response: bytes, ttl: int = 300):
size = len(raw_response)
if size > self._max // 4 or size == 0:
return
# Evict oldest to make room
while self._size + size > self._max and self._store:
oldest = next(iter(self._store))
self._size -= len(self._store[oldest][0])
del self._store[oldest]
if url in self._store:
self._size -= len(self._store[url][0])
self._store[url] = (raw_response, time.time() + ttl)
self._size += size
@staticmethod
def parse_ttl(raw_response: bytes, url: str) -> int:
"""Determine cache TTL from response headers and URL."""
hdr_end = raw_response.find(b"\r\n\r\n")
if hdr_end < 0:
return 0
hdr = raw_response[:hdr_end].decode(errors="replace").lower()
# Don't cache errors or non-200
if b"HTTP/1.1 200" not in raw_response[:20]:
return 0
if "no-store" in hdr:
return 0
# Explicit max-age
m = re.search(r"max-age=(\d+)", hdr)
if m:
return min(int(m.group(1)), 86400)
# Heuristic by content type / extension
path = url.split("?")[0].lower()
static_exts = (
".css", ".js", ".woff", ".woff2", ".ttf", ".eot",
".png", ".jpg", ".jpeg", ".gif", ".webp", ".svg", ".ico",
".mp3", ".mp4", ".wasm",
)
for ext in static_exts:
if path.endswith(ext):
return 3600 # 1 hour for static assets
ct_m = re.search(r"content-type:\s*([^\r\n]+)", hdr)
ct = ct_m.group(1) if ct_m else ""
if "image/" in ct or "font/" in ct:
return 3600
if "text/css" in ct or "javascript" in ct:
return 1800
if "text/html" in ct or "application/json" in ct:
return 0 # don't cache dynamic content by default
return 0
class ProxyServer:
def __init__(self, config: dict):
self.host = config.get("listen_host", "127.0.0.1")
self.port = config.get("listen_port", 8080)
self.mode = config.get("mode", "domain_fronting")
self.fronter = DomainFronter(config)
self.mitm = None
self._cache = ResponseCache(max_mb=50)
# Persistent HTTP tunnel cache for google_fronting mode
# Key: "host:port" → (tunnel_reader, tunnel_writer, lock)
self._http_tunnels: dict = {}
self._tunnel_lock = asyncio.Lock()
# hosts override — DNS fake-map: domain/suffix → IP
# Checked before any real DNS lookup; supports exact and suffix matching.
self._hosts: dict[str, str] = config.get("hosts", {})
if self.mode == "apps_script":
try:
from core.mitm import MITMCertManager
self.mitm = MITMCertManager()
except ImportError:
log.error("apps_script mode requires 'cryptography' package.")
log.error("Run: pip install cryptography")
raise SystemExit(1)
async def start(self):
srv = await asyncio.start_server(self._on_client, self.host, self.port)
log.info(
"Listening on %s:%d — configure your browser HTTP proxy to this address",
self.host, self.port,
)
async with srv:
await srv.serve_forever()
# ── client handler ────────────────────────────────────────────
async def _on_client(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
addr = writer.get_extra_info("peername")
try:
first_line = await asyncio.wait_for(reader.readline(), timeout=30)
if not first_line:
return
# Read remaining headers
header_block = first_line
while True:
line = await asyncio.wait_for(reader.readline(), timeout=10)
header_block += line
if line in (b"\r\n", b"\n", b""):
break
request_line = first_line.decode(errors="replace").strip()
parts = request_line.split(" ", 2)
if len(parts) < 2:
return
method = parts[0].upper()
if method == "CONNECT":
await self._do_connect(parts[1], reader, writer)
else:
await self._do_http(header_block, reader, writer)
except asyncio.TimeoutError:
log.debug("Timeout: %s", addr)
except Exception as e:
log.error("Error (%s): %s", addr, e)
finally:
try:
writer.close()
await writer.wait_closed()
except Exception:
pass
# ── CONNECT (HTTPS tunnelling) ────────────────────────────────
async def _do_connect(self, target: str, reader, writer):
host, _, port = target.rpartition(":")
port = int(port) if port else 443
if not host:
host, port = target, 443
log.info("CONNECT → %s:%d", host, port)
writer.write(b"HTTP/1.1 200 Connection Established\r\n\r\n")
await writer.drain()
if self.mode == "apps_script":
override_ip = self._sni_rewrite_ip(host)
if override_ip:
# SNI-blocked domain: MITM-decrypt from browser, then
# re-connect to the override IP with SNI=front_domain so
# the ISP never sees the blocked hostname in the TLS handshake.
log.info("SNI-rewrite tunnel → %s via %s (SNI: %s)",
host, override_ip, self.fronter.sni_host)
await self._do_sni_rewrite_tunnel(host, port, reader, writer,
connect_ip=override_ip)
elif self._is_google_domain(host):
log.info("Direct tunnel → %s (Google domain, skipping relay)", host)
await self._do_direct_tunnel(host, port, reader, writer)
else:
await self._do_mitm_connect(host, port, reader, writer)
else:
await self.fronter.tunnel(host, port, reader, writer)
# ── Hosts override (fake DNS) ─────────────────────────────────
# Built-in list of domains that must be reached via Google's frontend IP
# with SNI rewritten to `front_domain` (default: www.google.com).
# These are Google-owned services whose real SNI is DPI-blocked in some
# countries, but that Google serves from the same edge IP as www.google.com.
# Users don't need to configure anything — any host matching one of these
# suffixes is transparently SNI-rewritten to the configured `google_ip`.
# Config's "hosts" map still takes precedence (for custom overrides).
_SNI_REWRITE_SUFFIXES = (
"youtube.com",
"youtu.be",
"youtube-nocookie.com",
"ytimg.com",
"ggpht.com",
"gvt1.com",
"gvt2.com",
"doubleclick.net",
"googlesyndication.com",
"googleadservices.com",
"google-analytics.com",
"googletagmanager.com",
"googletagservices.com",
"fonts.googleapis.com",
)
def _sni_rewrite_ip(self, host: str) -> str | None:
"""Return the IP to SNI-rewrite `host` through, or None.
Order of precedence:
1. Explicit entry in config `hosts` map (exact or suffix match).
2. Built-in `_SNI_REWRITE_SUFFIXES` → mapped to config `google_ip`.
"""
ip = self._hosts_ip(host)
if ip:
return ip
h = host.lower().rstrip(".")
for suffix in self._SNI_REWRITE_SUFFIXES:
if h == suffix or h.endswith("." + suffix):
return self.fronter.connect_host # configured google_ip
return None
def _hosts_ip(self, host: str) -> str | None:
"""Return override IP for host if defined in config 'hosts', else None.
Supports exact match and suffix match (e.g. 'youtube.com' matches
'www.youtube.com', 'm.youtube.com', etc.).
"""
h = host.lower().rstrip(".")
if h in self._hosts:
return self._hosts[h]
# suffix match: check every parent label
parts = h.split(".")
for i in range(1, len(parts)):
parent = ".".join(parts[i:])
if parent in self._hosts:
return self._hosts[parent]
return None
# ── Google domain detection ───────────────────────────────────
# Only domains whose SNI the ISP does NOT block — direct tunnel is safe.
# YouTube/googlevideo SNIs are blocked; they go through _do_sni_rewrite_tunnel
# via the hosts map instead.
_GOOGLE_SUFFIXES = (
".google.com", ".google.co",
".googleapis.com", ".gstatic.com",
".googleusercontent.com",
)
_GOOGLE_EXACT = {
"google.com", "gstatic.com", "googleapis.com",
}
def _is_google_domain(self, host: str) -> bool:
"""Return True if host is a Google-owned domain."""
h = host.lower().rstrip(".")
if h in self._GOOGLE_EXACT:
return True
for suffix in self._GOOGLE_SUFFIXES:
if h.endswith(suffix):
return True
return False
# ── Direct tunnel (no MITM) ───────────────────────────────────
async def _do_direct_tunnel(self, host: str, port: int,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
connect_ip: str | None = None):
"""Pipe raw TLS bytes directly to the target server.
connect_ip overrides DNS: the TCP connection goes to that IP
while the browser's TLS (SNI=host) is piped through unchanged.
Defaults to the configured google_ip for Google-category domains.
"""
target_ip = connect_ip or self.fronter.connect_host
try:
r_remote, w_remote = await asyncio.wait_for(
asyncio.open_connection(target_ip, port), timeout=10
)
except Exception as e:
log.error("Direct tunnel connect failed (%s via %s): %s",
host, target_ip, e)
return
async def pipe(src, dst, label):
try:
while True:
data = await src.read(65536)
if not data:
break
dst.write(data)
await dst.drain()
except (ConnectionError, asyncio.CancelledError):
pass
except Exception as e:
log.debug("Pipe %s ended: %s", label, e)
finally:
try:
dst.close()
except Exception:
pass
await asyncio.gather(
pipe(reader, w_remote, f"client→{host}"),
pipe(r_remote, writer, f"{host}→client"),
)
# ── SNI-rewrite tunnel ────────────────────────────────────────
async def _do_sni_rewrite_tunnel(self, host: str, port: int, reader, writer,
connect_ip: str | None = None):
"""MITM-decrypt TLS from browser, then re-encrypt toward connect_ip
using SNI=front_domain (e.g. www.google.com).
The ISP only ever sees SNI=www.google.com in the outgoing handshake,
hiding the blocked hostname (e.g. www.youtube.com).
"""
target_ip = connect_ip or self.fronter.connect_host
sni_out = self.fronter.sni_host # e.g. "www.google.com"
# Step 1: MITM — accept TLS from the browser
ssl_ctx_server = self.mitm.get_server_context(host)
loop = asyncio.get_event_loop()
transport = writer.transport
protocol = transport.get_protocol()
try:
new_transport = await loop.start_tls(
transport, protocol, ssl_ctx_server, server_side=True,
)
except Exception as e:
log.debug("SNI-rewrite TLS accept failed (%s): %s", host, e)
return
writer._transport = new_transport
# Step 2: open outgoing TLS to target IP with the safe SNI
ssl_ctx_client = ssl.create_default_context()
if not self.fronter.verify_ssl:
ssl_ctx_client.check_hostname = False
ssl_ctx_client.verify_mode = ssl.CERT_NONE
try:
r_out, w_out = await asyncio.wait_for(
asyncio.open_connection(
target_ip, port,
ssl=ssl_ctx_client,
server_hostname=sni_out,
),
timeout=10,
)
except Exception as e:
log.error("SNI-rewrite outbound connect failed (%s via %s): %s",
host, target_ip, e)
return
# Step 3: pipe application-layer bytes between the two TLS sessions
async def pipe(src, dst, label):
try:
while True:
data = await src.read(65536)
if not data:
break
dst.write(data)
await dst.drain()
except (ConnectionError, asyncio.CancelledError):
pass
except Exception as exc:
log.debug("Pipe %s ended: %s", label, exc)
finally:
try:
dst.close()
except Exception:
pass
await asyncio.gather(
pipe(reader, w_out, f"client→{host}"),
pipe(r_out, writer, f"{host}→client"),
)
# ── MITM CONNECT (apps_script mode) ───────────────────────────
async def _do_mitm_connect(self, host: str, port: int, reader, writer):
"""Intercept TLS, decrypt HTTP, and relay through Apps Script."""
ssl_ctx = self.mitm.get_server_context(host)
# Upgrade the existing connection to TLS (we are the server)
loop = asyncio.get_event_loop()
transport = writer.transport
protocol = transport.get_protocol()
try:
new_transport = await loop.start_tls(
transport, protocol, ssl_ctx, server_side=True,
)
except Exception as e:
# Non-HTTPS traffic (e.g. MTProto, plain HTTP on port 80/443)
# routed through the proxy will always fail TLS — log at DEBUG
# to avoid alarming noise.
if port != 443:
log.debug("TLS handshake skipped for %s:%d (non-HTTPS): %s", host, port, e)
else:
log.debug("TLS handshake failed for %s: %s", host, e)
return
# Update writer to use the new TLS transport
writer._transport = new_transport
# Read and relay HTTP requests from the browser (now decrypted)
while True:
try:
first_line = await asyncio.wait_for(reader.readline(), timeout=120)
if not first_line:
break
header_block = first_line
while True:
line = await asyncio.wait_for(reader.readline(), timeout=10)
header_block += line
if line in (b"\r\n", b"\n", b""):
break
# Read body
body = b""
for raw_line in header_block.split(b"\r\n"):
if raw_line.lower().startswith(b"content-length:"):
length = int(raw_line.split(b":", 1)[1].strip())
body = await reader.readexactly(length)
break
# Parse the request
request_line = first_line.decode(errors="replace").strip()
parts = request_line.split(" ", 2)
if len(parts) < 2:
break
method = parts[0]
path = parts[1]
# Parse headers
headers = {}
for raw_line in header_block.split(b"\r\n")[1:]:
if b":" in raw_line:
k, v = raw_line.decode(errors="replace").split(":", 1)
headers[k.strip()] = v.strip()
# Build full URL (browser sends just the path in CONNECT)
if port == 443:
url = f"https://{host}{path}"
else:
url = f"https://{host}:{port}{path}"
log.info("MITM → %s %s", method, url)
# ── CORS: extract relevant request headers ────────────────────
origin = next(
(v for k, v in headers.items() if k.lower() == "origin"), ""
)
acr_method = next(
(v for k, v in headers.items()
if k.lower() == "access-control-request-method"), ""
)
acr_headers = next(
(v for k, v in headers.items()
if k.lower() == "access-control-request-headers"), ""
)
# CORS preflight — respond directly; UrlFetchApp doesn't
# support OPTIONS so forwarding it would always fail.
if method.upper() == "OPTIONS" and acr_method:
log.debug("CORS preflight → %s (responding locally)", url[:60])
writer.write(self._cors_preflight_response(origin, acr_method, acr_headers))
await writer.drain()
continue
# Check local cache first (GET only)
response = None
if method == "GET" and not body:
response = self._cache.get(url)
if response:
log.debug("Cache HIT: %s", url[:60])
if response is None:
# Relay through Apps Script
try:
response = await self._relay_smart(method, url, headers, body)
except Exception as e:
log.error("Relay error (%s): %s", url[:60], e)
err_body = f"Relay error: {e}".encode()
response = (
b"HTTP/1.1 502 Bad Gateway\r\n"
b"Content-Type: text/plain\r\n"
b"Content-Length: " + str(len(err_body)).encode() + b"\r\n"
b"\r\n" + err_body
)
# Cache successful GET responses
if method == "GET" and not body and response:
ttl = ResponseCache.parse_ttl(response, url)
if ttl > 0:
self._cache.put(url, response, ttl)
log.debug("Cached (%ds): %s", ttl, url[:60])
# Inject permissive CORS headers whenever the browser
# sent an Origin (cross-origin XHR / fetch).
if origin and response:
response = self._inject_cors_headers(response, origin)
writer.write(response)
await writer.drain()
except asyncio.TimeoutError:
break
except asyncio.IncompleteReadError:
break
except ConnectionError:
break
except Exception as e:
log.error("MITM handler error (%s): %s", host, e)
break
# ── CORS helpers ──────────────────────────────────────────────────────────
@staticmethod
def _cors_preflight_response(origin: str, acr_method: str, acr_headers: str) -> bytes:
"""Return a 204 No Content response that satisfies a CORS preflight."""
allow_origin = origin or "*"
allow_methods = (
f"{acr_method}, GET, POST, PUT, DELETE, PATCH, OPTIONS"
if acr_method else
"GET, POST, PUT, DELETE, PATCH, OPTIONS"
)
allow_headers = acr_headers or "*"
return (
"HTTP/1.1 204 No Content\r\n"
f"Access-Control-Allow-Origin: {allow_origin}\r\n"
f"Access-Control-Allow-Methods: {allow_methods}\r\n"
f"Access-Control-Allow-Headers: {allow_headers}\r\n"
"Access-Control-Allow-Credentials: true\r\n"
"Access-Control-Max-Age: 86400\r\n"
"Vary: Origin\r\n"
"Content-Length: 0\r\n"
"\r\n"
).encode()
@staticmethod
def _inject_cors_headers(response: bytes, origin: str) -> bytes:
"""Inject CORS headers only if the upstream response lacks them.
We must NOT overwrite the origin server's CORS headers: sites like
x.com return carefully-scoped Access-Control-Allow-Headers that list
specific custom headers (e.g. x-csrf-token). Replacing them with
wildcards together with Allow-Credentials: true makes browsers
reject the response (per the Fetch spec, "*" is literal when
credentials are included), which the site then blames on privacy
extensions. So we only fill in what the server omitted.
"""
sep = b"\r\n\r\n"
if sep not in response:
return response
header_section, body = response.split(sep, 1)
lines = header_section.decode(errors="replace").split("\r\n")
existing = {ln.split(":", 1)[0].strip().lower()
for ln in lines if ":" in ln}
# If the upstream already handled CORS, leave it completely alone.
if "access-control-allow-origin" in existing:
return response
# Otherwise inject a minimal, credential-safe set (no wildcards,
# since wildcards combined with credentials are invalid).
allow_origin = origin or "*"
additions = [f"Access-Control-Allow-Origin: {allow_origin}"]
if allow_origin != "*":
additions.append("Access-Control-Allow-Credentials: true")
additions.append("Vary: Origin")
return ("\r\n".join(lines + additions) + "\r\n\r\n").encode() + body
async def _relay_smart(self, method, url, headers, body):
"""Choose optimal relay strategy based on request type.
- GET requests for likely-large downloads use parallel-range.
- All other requests (API calls, HTML, JSON, XHR) go through the
single-request relay. This avoids injecting a synthetic Range
header on normal traffic, which some origins honor by returning
206 — breaking fetch()/XHR on sites like x.com or Cloudflare
challenge pages.
"""
if method == "GET" and not body:
# Respect client's own Range header verbatim.
if headers:
for k in headers:
if k.lower() == "range":
return await self.fronter.relay(
method, url, headers, body
)
# Only probe with Range when the URL looks like a big file.
if self._is_likely_download(url, headers):
return await self.fronter.relay_parallel(
method, url, headers, body
)
return await self.fronter.relay(method, url, headers, body)
def _is_likely_download(self, url: str, headers: dict) -> bool:
"""Heuristic: is this URL likely a large file download?"""
# Check file extension
path = url.split("?")[0].lower()
large_exts = {
".zip", ".tar", ".gz", ".bz2", ".xz", ".7z", ".rar",
".exe", ".msi", ".dmg", ".deb", ".rpm", ".apk",
".iso", ".img",
".mp4", ".mkv", ".avi", ".mov", ".webm",
".mp3", ".flac", ".wav", ".aac",
".pdf", ".doc", ".docx", ".ppt", ".pptx",
".wasm",
}
for ext in large_exts:
if path.endswith(ext):
return True
return False
# ── Plain HTTP forwarding ─────────────────────────────────────
async def _do_http(self, header_block: bytes, reader, writer):
body = b""
for raw_line in header_block.split(b"\r\n"):
if raw_line.lower().startswith(b"content-length:"):
length = int(raw_line.split(b":", 1)[1].strip())
body = await reader.readexactly(length)
break
first_line = header_block.split(b"\r\n")[0].decode(errors="replace")
log.info("HTTP → %s", first_line)
if self.mode == "apps_script":
# Parse request and relay through Apps Script
parts = first_line.strip().split(" ", 2)
method = parts[0] if parts else "GET"
url = parts[1] if len(parts) > 1 else "/"
headers = {}
for raw_line in header_block.split(b"\r\n")[1:]:
if b":" in raw_line:
k, v = raw_line.decode(errors="replace").split(":", 1)
headers[k.strip()] = v.strip()
# ── CORS preflight over plain HTTP ────────────────────────────
origin = next(
(v for k, v in headers.items() if k.lower() == "origin"), ""
)
acr_method = next(
(v for k, v in headers.items()
if k.lower() == "access-control-request-method"), ""
)
acr_headers_val = next(
(v for k, v in headers.items()
if k.lower() == "access-control-request-headers"), ""
)
if method.upper() == "OPTIONS" and acr_method:
log.debug("CORS preflight (HTTP) → %s (responding locally)", url[:60])
writer.write(self._cors_preflight_response(origin, acr_method, acr_headers_val))
await writer.drain()
return
# Cache check for GET
response = None
if method == "GET" and not body:
response = self._cache.get(url)
if response:
log.debug("Cache HIT (HTTP): %s", url[:60])
if response is None:
response = await self._relay_smart(method, url, headers, body)
# Cache successful GET
if method == "GET" and not body and response:
ttl = ResponseCache.parse_ttl(response, url)
if ttl > 0:
self._cache.put(url, response, ttl)
# Inject CORS headers for cross-origin requests
if origin and response:
response = self._inject_cors_headers(response, origin)
elif self.mode in ("google_fronting", "custom_domain", "domain_fronting"):
# Use WebSocket tunnel for ALL traffic (much faster than forward())
response = await self._tunnel_http(header_block, body)
else:
response = await self.fronter.forward(header_block + body)
writer.write(response)
await writer.drain()
async def _tunnel_http(self, header_block: bytes, body: bytes) -> bytes:
"""Forward plain HTTP via a persistent WebSocket tunnel.
Instead of opening a new TLS+HTTP connection for each request
(the old forward() path), this keeps a WebSocket tunnel open
to the target host and pipes raw HTTP through it.
Much faster for rapid-fire requests (e.g., Telegram API).
"""
import re as _re
# Parse target host:port from the raw HTTP request
host = ""
port = 80
for line in header_block.split(b"\r\n")[1:]:
if not line:
break
if line.lower().startswith(b"host:"):
host_val = line.split(b":", 1)[1].strip().decode(errors="replace")
if ":" in host_val:
h, p = host_val.rsplit(":", 1)
try:
host, port = h, int(p)
except ValueError:
host = host_val
else:
host = host_val
break
if not host:
return b"HTTP/1.1 400 Bad Request\r\n\r\nNo Host header\r\n"
# Rewrite the request line: browser sends absolute URL
# (e.g., "GET http://host/path HTTP/1.1") but the target
# server expects a relative path ("GET /path HTTP/1.1")
first_line = header_block.split(b"\r\n")[0]
first_str = first_line.decode(errors="replace")
parts = first_str.split(" ", 2)
if len(parts) >= 2 and parts[1].startswith("http://"):
from urllib.parse import urlparse
parsed = urlparse(parts[1])
rel_path = parsed.path or "/"
if parsed.query:
rel_path += "?" + parsed.query
new_first = f"{parts[0]} {rel_path}"
if len(parts) == 3:
new_first += f" {parts[2]}"
header_block = new_first.encode() + b"\r\n" + b"\r\n".join(header_block.split(b"\r\n")[1:])
raw_request = header_block + body
# Send through tunnel
try:
return await asyncio.wait_for(
self.fronter.forward(raw_request), timeout=30
)
except Exception as e:
log.error("Tunnel HTTP failed (%s:%d): %s", host, port, e)
return b"HTTP/1.1 502 Bad Gateway\r\n\r\nTunnel forward failed\r\n"
-76
View File
@@ -1,76 +0,0 @@
"""
Minimal WebSocket frame encoder / decoder (RFC 6455).
Only handles binary (opcode 0x02) and close (opcode 0x08) frames.
Client-to-server frames are always masked as required by the spec.
"""
import os
import struct
def ws_encode(data: bytes, opcode: int = 0x02) -> bytes:
"""Encode *data* into a masked binary WebSocket frame."""
head = bytearray([0x80 | opcode]) # FIN + opcode
length = len(data)
if length < 126:
head.append(0x80 | length)
elif length < 0x10000:
head.append(0x80 | 126)
head += struct.pack("!H", length)
else:
head.append(0x80 | 127)
head += struct.pack("!Q", length)
mask = os.urandom(4)
head += mask
masked = bytearray(data)
for i in range(len(masked)):
masked[i] ^= mask[i & 3]
return bytes(head) + bytes(masked)
def ws_decode(buf: bytes):
"""Try to decode one frame from *buf*.
Returns ``(opcode, payload, consumed_bytes)`` or ``None`` if the
buffer does not yet contain a complete frame.
"""
if len(buf) < 2:
return None
opcode = buf[0] & 0x0F
is_masked = buf[1] & 0x80
length = buf[1] & 0x7F
pos = 2
if length == 126:
if len(buf) < 4:
return None
length = struct.unpack("!H", buf[2:4])[0]
pos = 4
elif length == 127:
if len(buf) < 10:
return None
length = struct.unpack("!Q", buf[2:10])[0]
pos = 10
mask = None
if is_masked:
if len(buf) < pos + 4:
return None
mask = buf[pos : pos + 4]
pos += 4
if len(buf) < pos + length:
return None
payload = bytearray(buf[pos : pos + length])
if mask:
for i in range(len(payload)):
payload[i] ^= mask[i & 3]
return opcode, bytes(payload), pos + length
+299 -51
View File
@@ -6,7 +6,7 @@ Also attempts to install into Firefox's NSS certificate store when found.
Usage:
from cert_installer import install_ca, is_ca_trusted
install_ca("/path/to/ca.crt", cert_name="MHR_CFW")
install_ca("/path/to/ca.crt", cert_name="mhr-cfw")
"""
import glob
@@ -18,7 +18,7 @@ import subprocess
import sys
import tempfile
log = logging.getLogger("CertInstaller")
log = logging.getLogger("Cert")
CERT_NAME = "mhr-cfw"
@@ -250,28 +250,68 @@ def _install_linux(cert_path: str, cert_name: str) -> bool:
return installed
def _is_trusted_linux(cert_path: str) -> bool:
"""Check if our cert thumbprint is in the system's OpenSSL trust bundle."""
thumbprint = _cert_thumbprint(cert_path)
if not thumbprint:
def _is_trusted_linux(cert_path: str, cert_name: str = CERT_NAME) -> bool:
"""Check whether the cert appears in common Linux trust stores."""
try:
from cryptography import x509 as _x509
from cryptography.hazmat.primitives import hashes as _hashes
except Exception:
return False
bundle_paths = [
"/etc/ssl/certs/ca-certificates.crt", # Debian/Ubuntu
"/etc/pki/tls/certs/ca-bundle.crt", # RHEL/Fedora
"/etc/ssl/ca-bundle.pem", # OpenSUSE
"/etc/ca-certificates/ca-certificates.crt",
]
# A fast heuristic: check if our CA cert file was copied to known dirs
try:
with open(cert_path, "rb") as f:
target_cert = _x509.load_pem_x509_certificate(f.read())
target_fp = target_cert.fingerprint(_hashes.SHA1())
except Exception:
return False
# First check the common anchor locations used by the installer.
expected_name = f"{cert_name.replace(' ', '_')}.crt"
anchor_dirs = [
"/usr/local/share/ca-certificates",
"/etc/pki/ca-trust/source/anchors",
"/etc/ca-certificates/trust-source/anchors",
]
for d in anchor_dirs:
if os.path.isdir(d):
for f in os.listdir(d):
if "DomainFront" in f or "domainfront" in f.lower():
try:
if not os.path.isdir(d):
continue
if expected_name in os.listdir(d):
return True
except OSError:
pass
# Fall back to scanning the system bundle files directly.
bundle_paths = [
"/etc/ssl/certs/ca-certificates.crt", # Debian/Ubuntu
"/etc/pki/tls/certs/ca-bundle.crt", # RHEL/Fedora
"/etc/ssl/ca-bundle.pem", # OpenSUSE
"/etc/ca-certificates/ca-certificates.crt",
]
begin = b"-----BEGIN CERTIFICATE-----"
end = b"-----END CERTIFICATE-----"
for bundle in bundle_paths:
try:
with open(bundle, "rb") as f:
data = f.read()
except OSError:
continue
for chunk in data.split(begin):
if end not in chunk:
continue
pem = begin + chunk.split(end, 1)[0] + end + b"\n"
try:
cert = _x509.load_pem_x509_certificate(pem)
except Exception:
continue
try:
if cert.fingerprint(_hashes.SHA1()) == target_fp:
return True
except Exception:
continue
return False
@@ -279,43 +319,225 @@ def _is_trusted_linux(cert_path: str) -> bool:
# Firefox NSS (cross-platform)
# ─────────────────────────────────────────────────────────────────────────────
# def _install_firefox(cert_path: str, cert_name: str):
# """Install into all detected Firefox profile NSS databases."""
# if not _has_cmd("certutil"):
# log.debug("NSS certutil not found — skipping Firefox install.")
# return
def _install_firefox(cert_path: str, cert_name: str):
"""Install into all detected Firefox profile NSS databases."""
if not _has_cmd("certutil"):
log.debug("NSS certutil not found — skipping Firefox install.")
return
# profile_dirs: list[str] = []
# system = platform.system()
profile_dirs: list[str] = []
system = platform.system()
# if system == "Windows":
# appdata = os.environ.get("APPDATA", "")
# profile_dirs += glob.glob(os.path.join(appdata, r"Mozilla\Firefox\Profiles\*"))
# elif system == "Darwin":
# profile_dirs += glob.glob(os.path.expanduser("~/Library/Application Support/Firefox/Profiles/*"))
# else:
# profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.default*"))
# profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.release*"))
if system == "Windows":
appdata = os.environ.get("APPDATA", "")
profile_dirs += glob.glob(os.path.join(appdata, r"Mozilla\Firefox\Profiles\*"))
elif system == "Darwin":
profile_dirs += glob.glob(os.path.expanduser("~/Library/Application Support/Firefox/Profiles/*"))
else:
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.default*"))
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.release*"))
# if not profile_dirs:
# log.debug("No Firefox profiles found.")
# return
if not profile_dirs:
log.debug("No Firefox profiles found.")
return
# for profile in profile_dirs:
# db = f"sql:{profile}" if os.path.exists(os.path.join(profile, "cert9.db")) else f"dbm:{profile}"
# try:
# # Remove old entry first (ignore errors)
# _run(["certutil", "-D", "-n", cert_name, "-d", db], check=False)
# _run([
# "certutil", "-A",
# "-n", cert_name,
# "-t", "CT,,",
# "-i", cert_path,
# "-d", db,
# ])
# log.info("Installed in Firefox profile: %s", os.path.basename(profile))
# except (subprocess.CalledProcessError, FileNotFoundError) as exc:
# log.warning("Firefox profile %s: %s", os.path.basename(profile), exc)
for profile in profile_dirs:
db = f"sql:{profile}" if os.path.exists(os.path.join(profile, "cert9.db")) else f"dbm:{profile}"
try:
# Remove old entry first (ignore errors)
_run(["certutil", "-D", "-n", cert_name, "-d", db], check=False)
_run([
"certutil", "-A",
"-n", cert_name,
"-t", "CT,,",
"-i", cert_path,
"-d", db,
])
log.info("Installed in Firefox profile: %s", os.path.basename(profile))
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.warning("Firefox profile %s: %s", os.path.basename(profile), exc)
def _uninstall_firefox(cert_name: str):
"""Remove certificate from all detected Firefox profile NSS databases."""
if not _has_cmd("certutil"):
log.debug("NSS certutil not found — skipping Firefox uninstall.")
return
profile_dirs: list[str] = []
system = platform.system()
if system == "Windows":
appdata = os.environ.get("APPDATA", "")
profile_dirs += glob.glob(os.path.join(appdata, r"Mozilla\Firefox\Profiles\*"))
elif system == "Darwin":
profile_dirs += glob.glob(os.path.expanduser("~/Library/Application Support/Firefox/Profiles/*"))
else:
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.default*"))
profile_dirs += glob.glob(os.path.expanduser("~/.mozilla/firefox/*.release*"))
if not profile_dirs:
log.debug("No Firefox profiles found.")
return
for profile in profile_dirs:
db = f"sql:{profile}" if os.path.exists(os.path.join(profile, "cert9.db")) else f"dbm:{profile}"
try:
result = _run(["certutil", "-D", "-n", cert_name, "-d", db], check=False)
if result.returncode == 0:
log.info("Removed from Firefox profile: %s", os.path.basename(profile))
else:
log.debug("Firefox profile %s: certificate not present", os.path.basename(profile))
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.debug("Firefox profile %s: %s", os.path.basename(profile), exc)
# ─────────────────────────────────────────────────────────────────────────────
# Uninstall functions
# ─────────────────────────────────────────────────────────────────────────────
def _uninstall_windows(cert_path: str, cert_name: str) -> bool:
"""Remove certificate from the Windows Trusted Root store."""
thumbprint = _cert_thumbprint(cert_path)
# Try per-user store first (no admin required)
try:
target = thumbprint if thumbprint else cert_name
_run(["certutil", "-delstore", "-user", "Root", target])
log.info("Certificate removed from Windows user Trusted Root store.")
return True
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.warning("certutil user store removal failed: %s", exc)
# Try system store (requires admin)
try:
target = thumbprint if thumbprint else cert_name
_run(["certutil", "-delstore", "Root", target])
log.info("Certificate removed from Windows system Trusted Root store.")
return True
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.warning("certutil system store removal failed: %s", exc)
# Fallback: use PowerShell
try:
if thumbprint:
ps_cmd = (
"Get-ChildItem Cert:\\CurrentUser\\Root | "
f"Where-Object {{ $_.Thumbprint -eq '{thumbprint}' }} | "
"Remove-Item -Force -ErrorAction SilentlyContinue"
)
else:
ps_cmd = (
"Get-ChildItem Cert:\\CurrentUser\\Root | "
f"Where-Object {{ $_.Subject -like '*CN={cert_name}*' -or $_.FriendlyName -eq '{cert_name}' }} | "
"Remove-Item -Force -ErrorAction SilentlyContinue"
)
_run(["powershell", "-NoProfile", "-Command", ps_cmd])
log.info("Certificate removal via PowerShell completed.")
return True
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.error("PowerShell removal failed: %s", exc)
return False
def _uninstall_macos(cert_name: str) -> bool:
"""Remove certificate from the macOS keychains."""
login_keychain = os.path.expanduser("~/Library/Keychains/login.keychain-db")
if not os.path.exists(login_keychain):
login_keychain = os.path.expanduser("~/Library/Keychains/login.keychain")
try:
_run([
"security", "delete-certificate",
"-c", cert_name,
login_keychain,
])
log.info("Certificate removed from macOS login keychain.")
return True
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.warning("login keychain removal failed: %s", exc)
# Try system keychain (needs sudo)
try:
_run([
"sudo", "security", "delete-certificate",
"-c", cert_name,
"/Library/Keychains/System.keychain",
])
log.info("Certificate removed from macOS system keychain.")
return True
except (subprocess.CalledProcessError, FileNotFoundError) as exc:
log.debug("System keychain removal failed: %s", exc)
return False
def _uninstall_linux(cert_path: str, cert_name: str) -> bool:
"""Remove certificate from Linux trust stores."""
distro = _detect_linux_distro()
log.info("Detected Linux distro family: %s", distro)
removed = False
if distro == "debian":
dest_file = f"/usr/local/share/ca-certificates/{cert_name.replace(' ', '_')}.crt"
try:
if os.path.exists(dest_file):
os.remove(dest_file)
_run(["update-ca-certificates"])
log.info("Certificate removed via update-ca-certificates.")
removed = True
except (OSError, subprocess.CalledProcessError) as exc:
log.warning("Debian removal failed (needs sudo?): %s", exc)
try:
_run(["sudo", "rm", "-f", dest_file])
_run(["sudo", "update-ca-certificates"])
log.info("Certificate removed via sudo update-ca-certificates.")
removed = True
except (subprocess.CalledProcessError, FileNotFoundError) as exc2:
log.warning("sudo Debian removal failed: %s", exc2)
elif distro == "rhel":
dest_file = f"/etc/pki/ca-trust/source/anchors/{cert_name.replace(' ', '_')}.crt"
try:
if os.path.exists(dest_file):
os.remove(dest_file)
_run(["update-ca-trust", "extract"])
log.info("Certificate removed via update-ca-trust.")
removed = True
except (OSError, subprocess.CalledProcessError) as exc:
log.warning("RHEL removal failed (needs sudo?): %s", exc)
try:
_run(["sudo", "rm", "-f", dest_file])
_run(["sudo", "update-ca-trust", "extract"])
log.info("Certificate removed via sudo update-ca-trust.")
removed = True
except (subprocess.CalledProcessError, FileNotFoundError) as exc2:
log.warning("sudo RHEL removal failed: %s", exc2)
elif distro == "arch":
dest_file = f"/etc/ca-certificates/trust-source/anchors/{cert_name.replace(' ', '_')}.crt"
try:
if os.path.exists(dest_file):
os.remove(dest_file)
_run(["trust", "extract-compat"])
log.info("Certificate removed via trust extract-compat.")
removed = True
except (OSError, subprocess.CalledProcessError) as exc:
log.warning("Arch removal failed (needs sudo?): %s", exc)
try:
_run(["sudo", "rm", "-f", dest_file])
_run(["sudo", "trust", "extract-compat"])
log.info("Certificate removed via sudo trust extract-compat.")
removed = True
except (subprocess.CalledProcessError, FileNotFoundError) as exc2:
log.warning("sudo Arch removal failed: %s", exc2)
else:
log.warning("Unknown Linux distro. Manually remove %s from trusted CAs.", cert_name)
return removed
# ─────────────────────────────────────────────────────────────────────────────
@@ -330,7 +552,7 @@ def is_ca_trusted(cert_path: str) -> bool:
return _is_trusted_windows(cert_path)
if system == "Darwin":
return _is_trusted_macos(CERT_NAME)
return _is_trusted_linux(cert_path)
return _is_trusted_linux(cert_path, CERT_NAME)
except Exception:
return False
@@ -360,6 +582,32 @@ def install_ca(cert_path: str, cert_name: str = CERT_NAME) -> bool:
return False
# Best-effort Firefox install on all platforms
# _install_firefox(cert_path, cert_name)
_install_firefox(cert_path, cert_name)
return ok
def uninstall_ca(cert_path: str, cert_name: str = CERT_NAME) -> bool:
"""
Remove *cert_name* from the system's trusted root CAs on the current platform.
Also attempts Firefox NSS removal.
Returns True if the system store removal succeeded.
"""
system = platform.system()
log.info("Removing CA certificate from %s", system)
if system == "Windows":
ok = _uninstall_windows(cert_path, cert_name)
elif system == "Darwin":
ok = _uninstall_macos(cert_name)
elif system == "Linux":
ok = _uninstall_linux(cert_path, cert_name)
else:
log.error("Unsupported platform: %s", system)
return False
# Best-effort Firefox uninstall on all platforms
_uninstall_firefox(cert_name)
return ok