From 630a9906ceda3bcda510a4dc9d7487694e17e407 Mon Sep 17 00:00:00 2001
From: rlaphoenix
Date: Thu, 15 Feb 2024 16:07:42 +0000
Subject: [PATCH] Rework the Aria2c Downloader

- Downloads are now multithreaded directly in the downloader.
- Connections are now reused instead of being closed and reopened for every
  single download.
- Progress updates are now yielded back to the caller instead of having to
  drill a progress callable down through the call stack.
- Instead of parsing download progress information from the stdout stream in
  a very hacky way, use aria2's RPC interface.
- Added a new utility get_free_port, which is needed to choose aria2's RPC
  port. I do not want to use the default port in case the user is already
  using it for another tool or reason, and it also helps mitigate port
  scanning attacks that target default aria2 RPC ports.
- The config entry `aria2c.max_concurrent_downloads` is now actually used by
  aria2c when downloading.
- The `--max-concurrent-downloads` option and config value now default to
  `min(32,(cpu_count+4))` (usually around 16 on above-average systems)
  instead of 5.
- The automated pproxy proxy rerouting is now done via a subprocess instead
  of re-implementing what the pproxy entry point already does for us; less
  code, less trouble, and ultimately easier to implement.
---
 CONFIG.md                           |   7 +-
 devine/core/downloaders/__init__.py |   4 +-
 devine/core/downloaders/aria2c.py   | 304 +++++++++++++++++++---
 devine/core/utilities.py            |  43 ++--
 4 files changed, 225 insertions(+), 133 deletions(-)

diff --git a/CONFIG.md b/CONFIG.md
index 318f9eb..1cc0a6b 100644
--- a/CONFIG.md
+++ b/CONFIG.md
@@ -11,13 +11,12 @@ which does not keep comments.
 ## aria2c (dict)
 
 - `max_concurrent_downloads`
-  Maximum number of parallel downloads. Default: `5`
-  Note: Currently unused as downloads are multi-threaded by Devine rather than Aria2c.
-  Devine internally has a constant set value of 16 for it's parallel downloads.
+  Maximum number of parallel downloads. Default: `min(32,(cpu_count+4))`
+  Note: Overrides the `max_workers` parameter of the aria2(c) downloader function.
 - `max_connection_per_server`
   Maximum number of connections to one server for each download. Default: `1`
 - `split`
-  Split a file into N chunks and download each chunk on it's own connection. Default: `5`
+  Split a file into N chunks and download each chunk on its own connection. Default: `5`
 - `file_allocation`
   Specify file allocation method.
Default: `"prealloc"` diff --git a/devine/core/downloaders/__init__.py b/devine/core/downloaders/__init__.py index 6528ce8..7dfb36c 100644 --- a/devine/core/downloaders/__init__.py +++ b/devine/core/downloaders/__init__.py @@ -1,12 +1,10 @@ -import asyncio - from ..config import config from .aria2c import aria2c from .curl_impersonate import curl_impersonate from .requests import requests downloader = { - "aria2c": lambda *args, **kwargs: asyncio.run(aria2c(*args, **kwargs)), + "aria2c": aria2c, "curl_impersonate": curl_impersonate, "requests": requests }[config.downloader] diff --git a/devine/core/downloaders/aria2c.py b/devine/core/downloaders/aria2c.py index 15305fa..ba30b77 100644 --- a/devine/core/downloaders/aria2c.py +++ b/devine/core/downloaders/aria2c.py @@ -1,84 +1,144 @@ -import asyncio +import os import subprocess import textwrap +import time from functools import partial from http.cookiejar import CookieJar from pathlib import Path -from typing import MutableMapping, Optional, Union +from typing import Any, Callable, Generator, MutableMapping, Optional, Union +from urllib.parse import urlparse import requests +from Crypto.Random import get_random_bytes +from requests import Session from requests.cookies import RequestsCookieJar, cookiejar_from_dict, get_cookie_header +from rich import filesize from rich.text import Text from devine.core.config import config from devine.core.console import console -from devine.core.utilities import get_binary_path, start_pproxy +from devine.core.utilities import get_binary_path, get_free_port -async def aria2c( - uri: Union[str, list[str]], - out: Path, - headers: Optional[dict] = None, +def rpc(caller: Callable, secret: str, method: str, *params: Any) -> dict[str, Any]: + """Make a call to Aria2's JSON-RPC API.""" + rpc_res = caller( + json={ + "jsonrpc": "2.0", + "id": get_random_bytes(16).hex(), + "method": method, + "params": [f"token:{secret}", *params] + } + ).json() + if rpc_res.get("code"): + # wrap to console width - padding - '[Aria2c]: ' + error_pretty = "\n ".join(textwrap.wrap( + f"RPC Error: {rpc_res['message']} ({rpc_res['code']})".strip(), + width=console.width - 20, + initial_indent="" + )) + console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty)) + return rpc_res["result"] + + +def download( + urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]], + output_dir: Path, + filename: str, + headers: Optional[MutableMapping[str, Union[str, bytes]]] = None, cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None, proxy: Optional[str] = None, - silent: bool = False, - segmented: bool = False, - progress: Optional[partial] = None, - *args: str -) -> int: - """ - Download files using Aria2(c). - https://aria2.github.io + max_workers: Optional[int] = None +) -> Generator[dict[str, Any], None, None]: + if not urls: + raise ValueError("urls must be provided and not empty") + elif not isinstance(urls, (str, dict, list)): + raise TypeError(f"Expected urls to be {str} or {dict} or a list of one of them, not {type(urls)}") - If multiple URLs are provided they will be downloaded in the provided order - to the output directory. They will not be merged together. 
- """ - if not isinstance(uri, list): - uri = [uri] + if not output_dir: + raise ValueError("output_dir must be provided") + elif not isinstance(output_dir, Path): + raise TypeError(f"Expected output_dir to be {Path}, not {type(output_dir)}") - if cookies and not isinstance(cookies, CookieJar): - cookies = cookiejar_from_dict(cookies) + if not filename: + raise ValueError("filename must be provided") + elif not isinstance(filename, str): + raise TypeError(f"Expected filename to be {str}, not {type(filename)}") + + if not isinstance(headers, (MutableMapping, type(None))): + raise TypeError(f"Expected headers to be {MutableMapping}, not {type(headers)}") + + if not isinstance(cookies, (MutableMapping, RequestsCookieJar, type(None))): + raise TypeError(f"Expected cookies to be {MutableMapping} or {RequestsCookieJar}, not {type(cookies)}") + + if not isinstance(proxy, (str, type(None))): + raise TypeError(f"Expected proxy to be {str}, not {type(proxy)}") + + if not max_workers: + max_workers = min(32, (os.cpu_count() or 1) + 4) + elif not isinstance(max_workers, int): + raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}") + + if not isinstance(urls, list): + urls = [urls] executable = get_binary_path("aria2c", "aria2") if not executable: raise EnvironmentError("Aria2c executable not found...") - if proxy and proxy.lower().split(":")[0] != "http": - # HTTPS proxies are not supported by aria2(c). - # Proxy the proxy via pproxy to access it as an HTTP proxy. - async with start_pproxy(proxy) as pproxy_: - return await aria2c(uri, out, headers, cookies, pproxy_, silent, segmented, progress, *args) + if proxy and not proxy.lower().startswith("http://"): + raise ValueError("Only HTTP proxies are supported by aria2(c)") + + if cookies and not isinstance(cookies, CookieJar): + cookies = cookiejar_from_dict(cookies) - multiple_urls = len(uri) > 1 url_files = [] - for i, url in enumerate(uri): - url_text = url - if multiple_urls: - url_text += f"\n\tdir={out}" - url_text += f"\n\tout={i:08}.mp4" + for i, url in enumerate(urls): + if isinstance(url, str): + url_data = { + "url": url + } else: - url_text += f"\n\tdir={out.parent}" - url_text += f"\n\tout={out.name}" + url_data: dict[str, Any] = url + url_filename = filename.format( + i=i, + ext=Path(url_data["url"]).suffix + ) + url_text = url_data["url"] + url_text += f"\n\tdir={output_dir}" + url_text += f"\n\tout={url_filename}" if cookies: - mock_request = requests.Request(url=url) + mock_request = requests.Request(url=url_data["url"]) cookie_header = get_cookie_header(cookies, mock_request) if cookie_header: url_text += f"\n\theader=Cookie: {cookie_header}" + for key, value in url_data.items(): + if key == "url": + continue + if key == "headers": + for header_name, header_value in value.items(): + url_text += f"\n\theader={header_name}: {header_value}" + else: + url_text += f"\n\t{key}={value}" url_files.append(url_text) url_file = "\n".join(url_files) - max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", 5)) + rpc_port = get_free_port() + rpc_secret = get_random_bytes(16).hex() + rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc" + rpc_session = Session() + + max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers)) max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1)) split = int(config.aria2c.get("split", 5)) file_allocation = config.aria2c.get("file_allocation", "prealloc") - if segmented: + if len(urls) > 1: split = 1 file_allocation = 
"none" arguments = [ # [Basic Options] "--input-file", "-", - "--out", out.name, "--all-proxy", proxy or "", "--continue=true", # [Connection Options] @@ -92,11 +152,13 @@ async def aria2c( "--allow-overwrite=true", "--auto-file-renaming=false", "--console-log-level=warn", - f"--download-result={'default' if progress else 'hide'}", + "--download-result=default", f"--file-allocation={file_allocation}", "--summary-interval=0", - # [Extra Options] - *args + # [RPC Options] + "--enable-rpc=true", + f"--rpc-listen-port={rpc_port}", + f"--rpc-secret={rpc_secret}" ] for header, value in (headers or {}).items(): @@ -115,66 +177,44 @@ async def aria2c( arguments.extend(["--header", f"{header}: {value}"]) try: - p = await asyncio.create_subprocess_exec( - executable, - *arguments, + p = subprocess.Popen( + [ + executable, + *arguments + ], stdin=subprocess.PIPE, - stdout=subprocess.PIPE + stdout=subprocess.DEVNULL ) p.stdin.write(url_file.encode()) - await p.stdin.drain() p.stdin.close() - if p.stdout: - is_dl_summary = False - log_buffer = "" - while True: - try: - chunk = await p.stdout.readuntil(b"\r") - except asyncio.IncompleteReadError as e: - chunk = e.partial - if not chunk: + while p.poll() is None: + global_stats = rpc( + caller=partial(rpc_session.post, url=rpc_uri), + secret=rpc_secret, + method="aria2.getGlobalStat" + ) + if global_stats: + active = int(global_stats["numActive"]) + waiting = int(global_stats["numWaiting"]) + stopped = int(global_stats["numStopped"]) + total = active + waiting + stopped + yield dict( + total=total, + completed=stopped, + downloaded=f"{filesize.decimal(int(global_stats['downloadSpeed']))}/s" + ) + if total == stopped: + rpc( + caller=partial(rpc_session.post, url=rpc_uri), + secret=rpc_secret, + method="aria2.shutdown" + ) break - for line in chunk.decode().strip().splitlines(): - if not line: - continue - if line.startswith("Download Results"): - # we know it's 100% downloaded, but let's use the avg dl speed value - is_dl_summary = True - elif line.startswith("[") and line.endswith("]"): - if progress and "%" in line: - # id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs - # eta may not always be available - data_parts = line[1:-1].split() - perc_parts = data_parts[1].split("(") - if len(perc_parts) == 2: - # might otherwise be e.g., 0B/0B, with no % symbol provided - progress( - total=100, - completed=int(perc_parts[1][:-2]), - downloaded=f"{data_parts[3].split(':')[1]}/s" - ) - elif is_dl_summary and "OK" in line and "|" in line: - gid, status, avg_speed, path_or_uri = line.split("|") - progress(total=100, completed=100, downloaded=avg_speed.strip()) - elif not is_dl_summary: - if "aria2 will resume download if the transfer is restarted" in line: - continue - if "If there are any errors, then see the log file" in line: - continue - log_buffer += f"{line.strip()}\n" + time.sleep(1) - if log_buffer and not silent: - # wrap to console width - padding - '[Aria2c]: ' - log_buffer = "\n ".join(textwrap.wrap( - log_buffer.rstrip(), - width=console.width - 20, - initial_indent="" - )) - console.log(Text.from_ansi("\n[Aria2c]: " + log_buffer)) - - await p.wait() + p.wait() if p.returncode != 0: raise subprocess.CalledProcessError(p.returncode, arguments) @@ -188,7 +228,81 @@ async def aria2c( raise KeyboardInterrupt() raise - return p.returncode + +def aria2c( + urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]], + output_dir: Path, + filename: str, + headers: Optional[MutableMapping[str, Union[str, bytes]]] = None, + cookies: 
Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
+    proxy: Optional[str] = None,
+    max_workers: Optional[int] = None
+) -> Generator[dict[str, Any], None, None]:
+    """
+    Download files using Aria2(c).
+    https://aria2.github.io
+
+    Yields the following download status updates while chunks are downloading:
+
+    - {total: 100} (100% download total)
+    - {completed: 1} (1% download progress out of 100%)
+    - {downloaded: "10.1 MB/s"} (currently downloading at a rate of 10.1 MB/s)
+
+    The data is in the same format accepted by rich's progress.update() function.
+
+    Parameters:
+        urls: Web URL(s) to file(s) to download. You can use a dictionary with the key
+            "url" for the URI, and other keys for extra arguments to use per-URL.
+        output_dir: The folder to save the file into. If the save path's directory does
+            not exist then it will be made automatically.
+        filename: The filename or filename template to use for each file. The variables
+            you can use are `i` for the URL index and `ext` for the URL extension.
+        headers: A mapping of HTTP Header Key/Values to use for all downloads.
+        cookies: A mapping of Cookie Key/Values or a Cookie Jar to use for all downloads.
+        proxy: An optional proxy URI to route connections through for all downloads.
+        max_workers: The maximum number of threads to use for downloads. Defaults to
+            min(32,(cpu_count+4)). Used for aria2c's --max-concurrent-downloads option.
+    """
+    if proxy and not proxy.lower().startswith("http://"):
+        # Only HTTP proxies are supported by aria2(c)
+        proxy = urlparse(proxy)
+
+        port = get_free_port()
+        username, password = get_random_bytes(8).hex(), get_random_bytes(8).hex()
+        local_proxy = f"http://{username}:{password}@localhost:{port}"
+
+        scheme = {
+            "https": "http+ssl",
+            "socks5h": "socks"
+        }.get(proxy.scheme, proxy.scheme)
+
+        remote_server = f"{scheme}://{proxy.hostname}"
+        if proxy.port:
+            remote_server += f":{proxy.port}"
+        if proxy.username or proxy.password:
+            remote_server += "#"
+        if proxy.username:
+            remote_server += proxy.username
+        if proxy.password:
+            remote_server += f":{proxy.password}"
+
+        p = subprocess.Popen(
+            [
+                "pproxy",
+                "-l", f"http://:{port}#{username}:{password}",
+                "-r", remote_server
+            ],
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL
+        )
+
+        try:
+            yield from download(urls, output_dir, filename, headers, cookies, local_proxy, max_workers)
+        finally:
+            p.kill()
+            p.wait()
+        return
+
+    yield from download(urls, output_dir, filename, headers, cookies, proxy, max_workers)
 
 
 __all__ = ("aria2c",)
diff --git a/devine/core/utilities.py b/devine/core/utilities.py
index e35a2b4..5232a7a 100644
--- a/devine/core/utilities.py
+++ b/devine/core/utilities.py
@@ -3,6 +3,7 @@ import contextlib
 import importlib.util
 import re
 import shutil
+import socket
 import sys
 import time
 import unicodedata
@@ -10,11 +11,9 @@ from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
 from types import ModuleType
-from typing import AsyncIterator, Optional, Sequence, Union
-from urllib.parse import urlparse
+from typing import Optional, Sequence, Union
 
 import chardet
-import pproxy
 import requests
 from construct import ValidationError
 from langcodes import Language, closest_match
@@ -244,35 +243,17 @@ def try_ensure_utf8(data: bytes) -> bytes:
         return data
 
 
-@contextlib.asynccontextmanager
-async def start_pproxy(proxy: str) -> AsyncIterator[str]:
-    proxy = urlparse(proxy)
+def get_free_port() -> int:
+    """
+    Get an available port that is free at the time of calling.
-    scheme = {
-        "https": "http+ssl",
-        "socks5h": "socks"
-    }.get(proxy.scheme, proxy.scheme)
-
-    remote_server = f"{scheme}://{proxy.hostname}"
-    if proxy.port:
-        remote_server += f":{proxy.port}"
-    if proxy.username or proxy.password:
-        remote_server += "#"
-    if proxy.username:
-        remote_server += proxy.username
-    if proxy.password:
-        remote_server += f":{proxy.password}"
-
-    server = pproxy.Server("http://localhost:0")  # random port
-    remote = pproxy.Connection(remote_server)
-    handler = await server.start_server({"rserver": [remote]})
-
-    try:
-        port = handler.sockets[0].getsockname()[1]
-        yield f"http://localhost:{port}"
-    finally:
-        handler.close()
-        await handler.wait_closed()
+    The port is freed as soon as this function returns, so it is
+    possible for the port to be taken again before you try to use it.
+    """
+    with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+        s.bind(("", 0))  # port 0 asks the OS to pick any currently free port
+        return s.getsockname()[1]
 
 
 class FPS(ast.NodeVisitor):
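
Notes (for reviewers, not part of the patch itself):

The downloader is now a generator, and every yielded dict maps directly onto
the keyword arguments of rich's Progress.update(); unknown keys such as
"downloaded" become custom task fields. A minimal sketch of a caller, assuming
aria2c is the configured downloader and using a hypothetical URL and output
path (the progress-column setup is illustrative, not how Devine builds its
console):

    from pathlib import Path

    from rich.progress import Progress, TextColumn

    from devine.core.downloaders import downloader

    # a TextColumn is needed to render the custom "downloaded" field
    with Progress(
        *Progress.get_default_columns(),
        TextColumn("{task.fields[downloaded]}")
    ) as progress:
        task = progress.add_task("Downloading", total=None, downloaded="-")
        for update in downloader(
            urls="https://example.com/file.mp4",  # hypothetical URL
            output_dir=Path("downloads"),
            filename="file.mp4"
        ):
            # e.g. {"total": 3, "completed": 1, "downloaded": "10.1 MB/s"}
            progress.update(task, **update)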
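
The RPC polling loop amounts to a plain JSON-RPC 2.0 POST against aria2's
aria2.getGlobalStat method once per second. A standalone equivalent of a
single poll, with placeholder port and secret (the patch derives them from
get_free_port() and get_random_bytes() instead):

    import requests

    rpc_port = 6800         # placeholder; the patch picks a random free port
    rpc_secret = "hunter2"  # placeholder; the patch uses 16 random bytes, hex-encoded

    res = requests.post(
        f"http://127.0.0.1:{rpc_port}/jsonrpc",
        json={
            "jsonrpc": "2.0",
            "id": "1",
            "method": "aria2.getGlobalStat",
            "params": [f"token:{rpc_secret}"]
        }
    ).json()

    # aria2 reports all counters as strings, e.g.:
    # {"numActive": "2", "numWaiting": "3", "numStopped": "1", "downloadSpeed": "1048576", ...}
    stats = res["result"]
    total = int(stats["numActive"]) + int(stats["numWaiting"]) + int(stats["numStopped"])

Once numStopped equals the total, the loop issues aria2.shutdown over the same
channel rather than killing the process.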
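
The urls parameter also accepts per-URL dicts: "url" is the URI, "headers" is
expanded into header= lines, and any other key is written verbatim as an
option in aria2's input file. A hypothetical call exercising this (URLs,
header, and option values are made up):

    from pathlib import Path

    from devine.core.downloaders.aria2c import aria2c

    for update in aria2c(
        urls=[
            {
                "url": "https://example.com/seg-1.mp4",          # hypothetical URL
                "headers": {"Referer": "https://example.com/"},  # becomes a header= line
                "user-agent": "ExamplePlayer/1.0"                # raw aria2 input-file option
            },
            "https://example.com/seg-2.mp4"  # plain strings still work
        ],
        output_dir=Path("segments"),
        filename="{i:08}{ext}"  # formats to 00000000.mp4, 00000001.mp4, ...
    ):
        ...  # progress dicts, as described in the docstring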
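
Lastly, on get_free_port(): binding port 0 asks the kernel for an arbitrary
free port, but as the docstring warns, the port is released again the moment
the socket closes, so another process could take it before aria2c binds it.
If that race ever bites, a caller-side retry is the usual workaround; a
hypothetical sketch, not part of this patch:

    import socket

    from devine.core.utilities import get_free_port

    def grab_free_port(attempts: int = 3) -> socket.socket:
        """Hypothetical helper: re-roll if the 'free' port got taken in the meantime."""
        for _ in range(attempts):
            port = get_free_port()
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            try:
                sock.bind(("127.0.0.1", port))  # holding the bound socket reserves the port
                return sock
            except OSError:
                sock.close()
        raise OSError(f"no free port found after {attempts} attempts")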