fix: replace deprecated get_event_loop() with get_running_loop() in async functions

This commit is contained in:
Abolfazl
2026-04-23 22:38:59 +03:30
parent 1df9cf4d68
commit bca757a46a
3 changed files with 46 additions and 45 deletions
+38 -37
View File
@@ -202,7 +202,7 @@ class DomainFronter:
we rotate across `self._sni_hosts` so DPI can't fingerprint we rotate across `self._sni_hosts` so DPI can't fingerprint
"always www.google.com" from the client side. "always www.google.com" from the client side.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.setblocking(False) sock.setblocking(False)
@@ -228,7 +228,7 @@ class DomainFronter:
async def _acquire(self): async def _acquire(self):
"""Get a healthy TLS connection from pool (TTL-checked) or open new.""" """Get a healthy TLS connection from pool (TTL-checked) or open new."""
now = asyncio.get_event_loop().time() now = asyncio.get_running_loop().time()
async with self._pool_lock: async with self._pool_lock:
while self._pool: while self._pool:
reader, writer, created = self._pool.pop() reader, writer, created = self._pool.pop()
@@ -247,11 +247,11 @@ class DomainFronter:
if not self._refilling: if not self._refilling:
self._refilling = True self._refilling = True
self._spawn(self._refill_pool()) self._spawn(self._refill_pool())
return reader, writer, asyncio.get_event_loop().time() return reader, writer, asyncio.get_running_loop().time()
async def _release(self, reader, writer, created): async def _release(self, reader, writer, created):
"""Return a connection to the pool if still young and healthy.""" """Return a connection to the pool if still young and healthy."""
now = asyncio.get_event_loop().time() now = asyncio.get_running_loop().time()
if (now - created) >= self._conn_ttl or reader.at_eof(): if (now - created) >= self._conn_ttl or reader.at_eof():
try: try:
writer.close() writer.close()
@@ -442,6 +442,7 @@ class DomainFronter:
def _exec_path_for_sid(self, sid: str) -> str: def _exec_path_for_sid(self, sid: str) -> str:
"""Build the /macros/s/<sid>/(dev|exec) path for a specific script ID.""" """Build the /macros/s/<sid>/(dev|exec) path for a specific script ID."""
return f"/macros/s/{sid}/{'dev' if self._dev_available else 'exec'}" return f"/macros/s/{sid}/{'dev' if self._dev_available else 'exec'}"
async def _flush_pool(self): async def _flush_pool(self):
"""Close all pooled connections (they may be stale after errors).""" """Close all pooled connections (they may be stale after errors)."""
async with self._pool_lock: async with self._pool_lock:
@@ -464,7 +465,7 @@ class DomainFronter:
"""Open one TLS connection and add it to the pool.""" """Open one TLS connection and add it to the pool."""
try: try:
r, w = await asyncio.wait_for(self._open(), timeout=5) r, w = await asyncio.wait_for(self._open(), timeout=5)
t = asyncio.get_event_loop().time() t = asyncio.get_running_loop().time()
async with self._pool_lock: async with self._pool_lock:
if len(self._pool) < self._pool_max: if len(self._pool) < self._pool_max:
self._pool.append((r, w, t)) self._pool.append((r, w, t))
@@ -481,7 +482,7 @@ class DomainFronter:
while True: while True:
try: try:
await asyncio.sleep(3) await asyncio.sleep(3)
now = asyncio.get_event_loop().time() now = asyncio.get_running_loop().time()
# Purge expired / dead connections # Purge expired / dead connections
async with self._pool_lock: async with self._pool_lock:
@@ -713,7 +714,7 @@ class DomainFronter:
race where the owning task's `finally` pops the entry between race where the owning task's `finally` pops the entry between
the check and append by a second task. the check and append by a second task.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
async with self._batch_lock: async with self._batch_lock:
waiters = self._coalesce.get(url) waiters = self._coalesce.get(url)
if waiters is not None: if waiters is not None:
@@ -833,12 +834,12 @@ class DomainFronter:
f"chunk {s}-{e} failed after {max_tries} tries: {last_err}" f"chunk {s}-{e} failed after {max_tries} tries: {last_err}"
) )
t0 = asyncio.get_event_loop().time() t0 = asyncio.get_running_loop().time()
results = await asyncio.gather( results = await asyncio.gather(
*[fetch_range(s, e) for s, e in ranges], *[fetch_range(s, e) for s, e in ranges],
return_exceptions=True, return_exceptions=True,
) )
elapsed = asyncio.get_event_loop().time() - t0 elapsed = asyncio.get_running_loop().time() - t0
# Assemble full body # Assemble full body
parts = [resp_body] parts = [resp_body]
@@ -870,38 +871,38 @@ class DomainFronter:
chunk_size: int = 256 * 1024, chunk_size: int = 256 * 1024,
max_parallel: int = 16) -> bytes: max_parallel: int = 16) -> bytes:
"""Stream large file download to client as chunks arrive. """Stream large file download to client as chunks arrive.
Downloads file in parallel chunks and streams to client immediately, Downloads file in parallel chunks and streams to client immediately,
avoiding memory buildup and timeout issues for large files. avoiding memory buildup and timeout issues for large files.
Args: Args:
min_size: Minimum file size to enable chunking (0 = no minimum) min_size: Minimum file size to enable chunking (0 = no minimum)
chunk_size: Size of each chunk in bytes (default 256KB) chunk_size: Size of each chunk in bytes (default 256KB)
max_parallel: Maximum parallel chunk downloads (default 16) max_parallel: Maximum parallel chunk downloads (default 16)
""" """
if method != "GET" or body: if method != "GET" or body:
return await self.relay(method, url, headers, body) return await self.relay(method, url, headers, body)
# Probe: first chunk with Range header # Probe: first chunk with Range header
range_headers = dict(headers) if headers else {} range_headers = dict(headers) if headers else {}
range_headers["Range"] = f"bytes=0-{chunk_size - 1}" range_headers["Range"] = f"bytes=0-{chunk_size - 1}"
first_resp = await self.relay("GET", url, range_headers, b"") first_resp = await self.relay("GET", url, range_headers, b"")
status, resp_hdrs, resp_body = self._split_raw_response(first_resp) status, resp_hdrs, resp_body = self._split_raw_response(first_resp)
# No range support → return single response # No range support → return single response
if status != 206: if status != 206:
return first_resp return first_resp
# Parse total size from Content-Range # Parse total size from Content-Range
content_range = resp_hdrs.get("content-range", "") content_range = resp_hdrs.get("content-range", "")
m = re.search(r"/(\d+)", content_range) m = re.search(r"/(\d+)", content_range)
if not m: if not m:
return self._rewrite_206_to_200(first_resp) return self._rewrite_206_to_200(first_resp)
total_size = int(m.group(1)) total_size = int(m.group(1))
# Check minimum size threshold (if configured) # Check minimum size threshold (if configured)
if min_size > 0 and total_size < min_size: if min_size > 0 and total_size < min_size:
log.debug( log.debug(
@@ -909,11 +910,11 @@ class DomainFronter:
total_size, min_size total_size, min_size
) )
return self._rewrite_206_to_200(first_resp) return self._rewrite_206_to_200(first_resp)
# Small file (less than one chunk) → return immediately # Small file (less than one chunk) → return immediately
if total_size <= chunk_size or len(resp_body) >= total_size: if total_size <= chunk_size or len(resp_body) >= total_size:
return self._rewrite_206_to_200(first_resp) return self._rewrite_206_to_200(first_resp)
# Build response header # Build response header
response_header = "HTTP/1.1 200 OK\r\n" response_header = "HTTP/1.1 200 OK\r\n"
skip = {"transfer-encoding", "connection", "keep-alive", skip = {"transfer-encoding", "connection", "keep-alive",
@@ -922,15 +923,15 @@ class DomainFronter:
if k.lower() not in skip: if k.lower() not in skip:
response_header += f"{k}: {v}\r\n" response_header += f"{k}: {v}\r\n"
response_header += f"Content-Length: {total_size}\r\n\r\n" response_header += f"Content-Length: {total_size}\r\n\r\n"
# Send header + first chunk immediately # Send header + first chunk immediately
writer.write(response_header.encode() + resp_body) writer.write(response_header.encode() + resp_body)
await writer.drain() await writer.drain()
log.info("Streaming download: %d bytes, %d chunks of %d KB", log.info("Streaming download: %d bytes, %d chunks of %d KB",
total_size, (total_size + chunk_size - 1) // chunk_size, total_size, (total_size + chunk_size - 1) // chunk_size,
chunk_size // 1024) chunk_size // 1024)
# Calculate remaining ranges # Calculate remaining ranges
ranges = [] ranges = []
start = len(resp_body) start = len(resp_body)
@@ -938,21 +939,21 @@ class DomainFronter:
end = min(start + chunk_size - 1, total_size - 1) end = min(start + chunk_size - 1, total_size - 1)
ranges.append((start, end)) ranges.append((start, end))
start = end + 1 start = end + 1
# Download and stream chunks in order # Download and stream chunks in order
sem = asyncio.Semaphore(max_parallel) sem = asyncio.Semaphore(max_parallel)
chunk_buffer = {} # idx -> chunk_data chunk_buffer = {} # idx -> chunk_data
next_to_stream = 0 next_to_stream = 0
stream_event = asyncio.Event() stream_event = asyncio.Event()
all_downloaded = asyncio.Event() all_downloaded = asyncio.Event()
async def fetch_chunk(idx: int, s: int, e: int, max_retries: int = 5): async def fetch_chunk(idx: int, s: int, e: int, max_retries: int = 5):
"""Download chunk with retry and timeout.""" """Download chunk with retry and timeout."""
async with sem: async with sem:
rh = dict(headers) if headers else {} rh = dict(headers) if headers else {}
rh["Range"] = f"bytes={s}-{e}" rh["Range"] = f"bytes={s}-{e}"
expected_size = e - s + 1 expected_size = e - s + 1
for attempt in range(max_retries): for attempt in range(max_retries):
try: try:
# Low timeout for small chunks (256KB should download quickly) # Low timeout for small chunks (256KB should download quickly)
@@ -961,7 +962,7 @@ class DomainFronter:
timeout=15 timeout=15
) )
_, _, chunk_body = self._split_raw_response(raw) _, _, chunk_body = self._split_raw_response(raw)
# Verify chunk size # Verify chunk size
if len(chunk_body) != expected_size: if len(chunk_body) != expected_size:
log.warning( log.warning(
@@ -971,13 +972,13 @@ class DomainFronter:
if attempt < max_retries - 1: if attempt < max_retries - 1:
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
continue continue
# Store chunk in buffer # Store chunk in buffer
chunk_buffer[idx] = chunk_body chunk_buffer[idx] = chunk_body
stream_event.set() # Signal streamer stream_event.set() # Signal streamer
log.debug("Downloaded chunk %d/%d", idx + 1, len(ranges)) log.debug("Downloaded chunk %d/%d", idx + 1, len(ranges))
return return
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning( log.warning(
"Chunk %d timeout (retry %d/%d)", "Chunk %d timeout (retry %d/%d)",
@@ -992,10 +993,10 @@ class DomainFronter:
) )
if attempt < max_retries - 1: if attempt < max_retries - 1:
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
# All retries failed # All retries failed
raise Exception(f"Chunk {idx} failed after {max_retries} retries") raise Exception(f"Chunk {idx} failed after {max_retries} retries")
async def stream_chunks(): async def stream_chunks():
"""Stream chunks to client in sequential order.""" """Stream chunks to client in sequential order."""
nonlocal next_to_stream nonlocal next_to_stream
@@ -1008,25 +1009,25 @@ class DomainFronter:
raise Exception(f"Chunk {next_to_stream} never arrived") raise Exception(f"Chunk {next_to_stream} never arrived")
stream_event.clear() stream_event.clear()
await asyncio.wait_for(stream_event.wait(), timeout=30) await asyncio.wait_for(stream_event.wait(), timeout=30)
# Stream the chunk # Stream the chunk
chunk_data = chunk_buffer.pop(next_to_stream) chunk_data = chunk_buffer.pop(next_to_stream)
writer.write(chunk_data) writer.write(chunk_data)
await writer.drain() await writer.drain()
log.debug("Streamed chunk %d/%d", next_to_stream + 1, len(ranges)) log.debug("Streamed chunk %d/%d", next_to_stream + 1, len(ranges))
next_to_stream += 1 next_to_stream += 1
except Exception as e: except Exception as e:
log.error("Streaming error: %s", e) log.error("Streaming error: %s", e)
raise raise
# Start downloads and streaming concurrently # Start downloads and streaming concurrently
download_tasks = [ download_tasks = [
asyncio.create_task(fetch_chunk(i, s, e)) asyncio.create_task(fetch_chunk(i, s, e))
for i, (s, e) in enumerate(ranges) for i, (s, e) in enumerate(ranges)
] ]
stream_task = asyncio.create_task(stream_chunks()) stream_task = asyncio.create_task(stream_chunks())
# Wait for downloads to complete # Wait for downloads to complete
try: try:
await asyncio.gather(*download_tasks) await asyncio.gather(*download_tasks)
@@ -1041,7 +1042,7 @@ class DomainFronter:
if not stream_task.done(): if not stream_task.done():
stream_task.cancel() stream_task.cancel()
raise raise
# Return empty bytes since we already streamed everything # Return empty bytes since we already streamed everything
return b"" return b""
@@ -1144,7 +1145,7 @@ class DomainFronter:
if not self._batch_enabled: if not self._batch_enabled:
return await self._relay_with_retry(payload) return await self._relay_with_retry(payload)
future = asyncio.get_event_loop().create_future() future = asyncio.get_running_loop().create_future()
async with self._batch_lock: async with self._batch_lock:
self._batch_pending.append((payload, future)) self._batch_pending.append((payload, future))
+1 -1
View File
@@ -139,7 +139,7 @@ class H2Transport:
try: try:
await asyncio.wait_for( await asyncio.wait_for(
asyncio.get_event_loop().sock_connect( asyncio.get_running_loop().sock_connect(
raw, (self.connect_host, 443) raw, (self.connect_host, 443)
), ),
timeout=15, timeout=15,
+7 -7
View File
@@ -161,7 +161,7 @@ class ProxyServer:
self.socks_port = config.get("socks5_port", 1080) self.socks_port = config.get("socks5_port", 1080)
self.fronter = DomainFronter(config) self.fronter = DomainFronter(config)
self.mitm = None self.mitm = None
# Chunked download settings (configurable) # Chunked download settings (configurable)
exts = config.get("chunked_download_extensions", list(LARGE_FILE_EXTS)) exts = config.get("chunked_download_extensions", list(LARGE_FILE_EXTS))
self._chunked_extensions = frozenset(exts) self._chunked_extensions = frozenset(exts)
@@ -853,7 +853,7 @@ class ProxyServer:
# Step 1: MITM — accept TLS from the browser # Step 1: MITM — accept TLS from the browser
ssl_ctx_server = self.mitm.get_server_context(host) ssl_ctx_server = self.mitm.get_server_context(host)
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
transport = writer.transport transport = writer.transport
protocol = transport.get_protocol() protocol = transport.get_protocol()
try: try:
@@ -925,7 +925,7 @@ class ProxyServer:
ssl_ctx = self.mitm.get_server_context(host) ssl_ctx = self.mitm.get_server_context(host)
# Upgrade the existing connection to TLS (we are the server) # Upgrade the existing connection to TLS (we are the server)
loop = asyncio.get_event_loop() loop = asyncio.get_running_loop()
transport = writer.transport transport = writer.transport
protocol = transport.get_protocol() protocol = transport.get_protocol()
@@ -1012,11 +1012,11 @@ class ProxyServer:
if b":" in raw_line: if b":" in raw_line:
k, v = raw_line.decode(errors="replace").split(":", 1) k, v = raw_line.decode(errors="replace").split(":", 1)
headers[k.strip()] = v.strip() headers[k.strip()] = v.strip()
# Shortening the length of X API URLs to prevent relay errors. # Shortening the length of X API URLs to prevent relay errors.
if host == "x.com" and re.match(r"/i/api/graphql/[^/]+/[^?]+\?variables=", path): if host == "x.com" and re.match(r"/i/api/graphql/[^/]+/[^?]+\?variables=", path):
path = path.split("&")[0] path = path.split("&")[0]
# MITM traffic arrives as origin-form paths; SOCKS/plain HTTP can # MITM traffic arrives as origin-form paths; SOCKS/plain HTTP can
# also send absolute-form requests. Normalize both to full URLs. # also send absolute-form requests. Normalize both to full URLs.
if path.startswith("http://") or path.startswith("https://"): if path.startswith("http://") or path.startswith("https://"):
@@ -1196,13 +1196,13 @@ class ProxyServer:
def _is_likely_download(self, url: str, headers: dict) -> bool: def _is_likely_download(self, url: str, headers: dict) -> bool:
"""Heuristic: is this URL likely a large file download? """Heuristic: is this URL likely a large file download?
If ".*" is in chunked_download_extensions, bypasses extension check If ".*" is in chunked_download_extensions, bypasses extension check
and returns True for all URLs (enables chunking for any file). and returns True for all URLs (enables chunking for any file).
""" """
if self._chunked_bypass_check: if self._chunked_bypass_check:
return True return True
path = url.split("?")[0].lower() path = url.split("?")[0].lower()
for ext in self._chunked_extensions: for ext in self._chunked_extensions:
if path.endswith(ext.lower()): if path.endswith(ext.lower()):