diff --git a/CONFIG.md b/CONFIG.md
index 318f9eb..1cc0a6b 100644
--- a/CONFIG.md
+++ b/CONFIG.md
@@ -11,13 +11,12 @@ which does not keep comments.
 ## aria2c (dict)
 
 - `max_concurrent_downloads`
-  Maximum number of parallel downloads. Default: `5`
-  Note: Currently unused as downloads are multi-threaded by Devine rather than Aria2c.
-  Devine internally has a constant set value of 16 for it's parallel downloads.
+  Maximum number of parallel downloads. Default: `min(32,(cpu_count+4))`
+  Note: Overrides the `max_workers` parameter of the aria2(c) downloader function.
 - `max_connection_per_server`
   Maximum number of connections to one server for each download. Default: `1`
 - `split`
-  Split a file into N chunks and download each chunk on it's own connection. Default: `5`
+  Split a file into N chunks and download each chunk on its own connection. Default: `5`
 - `file_allocation`
   Specify file allocation method. Default: `"prealloc"`
 
diff --git a/devine/core/downloaders/__init__.py b/devine/core/downloaders/__init__.py
index 6528ce8..7dfb36c 100644
--- a/devine/core/downloaders/__init__.py
+++ b/devine/core/downloaders/__init__.py
@@ -1,12 +1,10 @@
-import asyncio
-
 from ..config import config
 from .aria2c import aria2c
 from .curl_impersonate import curl_impersonate
 from .requests import requests
 
 downloader = {
-    "aria2c": lambda *args, **kwargs: asyncio.run(aria2c(*args, **kwargs)),
+    "aria2c": aria2c,
     "curl_impersonate": curl_impersonate,
     "requests": requests
 }[config.downloader]
diff --git a/devine/core/downloaders/aria2c.py b/devine/core/downloaders/aria2c.py
index 15305fa..ba30b77 100644
--- a/devine/core/downloaders/aria2c.py
+++ b/devine/core/downloaders/aria2c.py
@@ -1,84 +1,144 @@
-import asyncio
+import os
 import subprocess
 import textwrap
+import time
 from functools import partial
 from http.cookiejar import CookieJar
 from pathlib import Path
-from typing import MutableMapping, Optional, Union
+from typing import Any, Callable, Generator, MutableMapping, Optional, Union
+from urllib.parse import urlparse
 
 import requests
+from Crypto.Random import get_random_bytes
+from requests import Session
 from requests.cookies import RequestsCookieJar, cookiejar_from_dict, get_cookie_header
+from rich import filesize
 from rich.text import Text
 
 from devine.core.config import config
 from devine.core.console import console
-from devine.core.utilities import get_binary_path, start_pproxy
+from devine.core.utilities import get_binary_path, get_free_port
 
 
-async def aria2c(
-    uri: Union[str, list[str]],
-    out: Path,
-    headers: Optional[dict] = None,
+def rpc(caller: Callable, secret: str, method: str, *params: Any) -> Any:
+    """Make a call to Aria2's JSON-RPC API."""
+    rpc_res = caller(
+        json={
+            "jsonrpc": "2.0",
+            "id": get_random_bytes(16).hex(),
+            "method": method,
+            "params": [f"token:{secret}", *params]
+        }
+    ).json()
+    if rpc_res.get("error"):
+        # wrap to console width - padding - '[Aria2c]: '
+        error_pretty = "\n ".join(textwrap.wrap(
+            f"RPC Error: {rpc_res['error']['message']} ({rpc_res['error']['code']})".strip(),
+            width=console.width - 20,
+            initial_indent=""
+        ))
+        console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
+    return rpc_res.get("result")
+
+
+def download(
+    urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
+    output_dir: Path,
+    filename: str,
+    headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
     cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
     proxy: Optional[str] = None,
-    silent: bool = False,
-    segmented: bool = False,
-    progress: Optional[partial] = None,
-    *args: str
-) -> int:
-    """
-    Download files using Aria2(c).
-    https://aria2.github.io
+    max_workers: Optional[int] = None
+) -> Generator[dict[str, Any], None, None]:
+    if not urls:
+        raise ValueError("urls must be provided and not empty")
+    elif not isinstance(urls, (str, dict, list)):
+        raise TypeError(f"Expected urls to be {str} or {dict} or a list of one of them, not {type(urls)}")
 
-    If multiple URLs are provided they will be downloaded in the provided order
-    to the output directory. They will not be merged together.
-    """
-    if not isinstance(uri, list):
-        uri = [uri]
+    if not output_dir:
+        raise ValueError("output_dir must be provided")
+    elif not isinstance(output_dir, Path):
+        raise TypeError(f"Expected output_dir to be {Path}, not {type(output_dir)}")
 
-    if cookies and not isinstance(cookies, CookieJar):
-        cookies = cookiejar_from_dict(cookies)
+    if not filename:
+        raise ValueError("filename must be provided")
+    elif not isinstance(filename, str):
+        raise TypeError(f"Expected filename to be {str}, not {type(filename)}")
+
+    if not isinstance(headers, (MutableMapping, type(None))):
+        raise TypeError(f"Expected headers to be {MutableMapping}, not {type(headers)}")
+
+    if not isinstance(cookies, (MutableMapping, RequestsCookieJar, type(None))):
+        raise TypeError(f"Expected cookies to be {MutableMapping} or {RequestsCookieJar}, not {type(cookies)}")
+
+    if not isinstance(proxy, (str, type(None))):
+        raise TypeError(f"Expected proxy to be {str}, not {type(proxy)}")
+
+    if not max_workers:
+        max_workers = min(32, (os.cpu_count() or 1) + 4)
+    elif not isinstance(max_workers, int):
+        raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")
+
+    if not isinstance(urls, list):
+        urls = [urls]
 
     executable = get_binary_path("aria2c", "aria2")
     if not executable:
         raise EnvironmentError("Aria2c executable not found...")
 
-    if proxy and proxy.lower().split(":")[0] != "http":
-        # HTTPS proxies are not supported by aria2(c).
-        # Proxy the proxy via pproxy to access it as an HTTP proxy.
-        async with start_pproxy(proxy) as pproxy_:
-            return await aria2c(uri, out, headers, cookies, pproxy_, silent, segmented, progress, *args)
+    if proxy and not proxy.lower().startswith("http://"):
+        raise ValueError("Only HTTP proxies are supported by aria2(c)")
+
+    if cookies and not isinstance(cookies, CookieJar):
+        cookies = cookiejar_from_dict(cookies)
 
-    multiple_urls = len(uri) > 1
     url_files = []
-    for i, url in enumerate(uri):
-        url_text = url
-        if multiple_urls:
-            url_text += f"\n\tdir={out}"
-            url_text += f"\n\tout={i:08}.mp4"
+    for i, url in enumerate(urls):
+        if isinstance(url, str):
+            url_data = {
+                "url": url
+            }
         else:
-            url_text += f"\n\tdir={out.parent}"
-            url_text += f"\n\tout={out.name}"
+            url_data: dict[str, Any] = url
+        url_filename = filename.format(
+            i=i,
+            ext=Path(url_data["url"]).suffix
+        )
+        url_text = url_data["url"]
+        url_text += f"\n\tdir={output_dir}"
+        url_text += f"\n\tout={url_filename}"
         if cookies:
-            mock_request = requests.Request(url=url)
+            mock_request = requests.Request(url=url_data["url"])
             cookie_header = get_cookie_header(cookies, mock_request)
             if cookie_header:
                 url_text += f"\n\theader=Cookie: {cookie_header}"
+        for key, value in url_data.items():
+            if key == "url":
+                continue
+            if key == "headers":
+                for header_name, header_value in value.items():
+                    url_text += f"\n\theader={header_name}: {header_value}"
+            else:
+                url_text += f"\n\t{key}={value}"
         url_files.append(url_text)
     url_file = "\n".join(url_files)
 
-    max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", 5))
+    rpc_port = get_free_port()
+    rpc_secret = get_random_bytes(16).hex()
+    rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
+    rpc_session = Session()
+
+    max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
     max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
     split = int(config.aria2c.get("split", 5))
     file_allocation = config.aria2c.get("file_allocation", "prealloc")
-    if segmented:
+    if len(urls) > 1:
         split = 1
         file_allocation = "none"
 
     arguments = [
         # [Basic Options]
         "--input-file", "-",
-        "--out", out.name,
         "--all-proxy", proxy or "",
         "--continue=true",
         # [Connection Options]
@@ -92,11 +152,13 @@ async def aria2c(
         "--allow-overwrite=true",
         "--auto-file-renaming=false",
         "--console-log-level=warn",
-        f"--download-result={'default' if progress else 'hide'}",
+        "--download-result=default",
         f"--file-allocation={file_allocation}",
         "--summary-interval=0",
-        # [Extra Options]
-        *args
+        # [RPC Options]
+        "--enable-rpc=true",
+        f"--rpc-listen-port={rpc_port}",
+        f"--rpc-secret={rpc_secret}"
     ]
 
     for header, value in (headers or {}).items():
@@ -115,66 +177,44 @@ async def aria2c(
         arguments.extend(["--header", f"{header}: {value}"])
 
     try:
-        p = await asyncio.create_subprocess_exec(
-            executable,
-            *arguments,
+        p = subprocess.Popen(
+            [
+                executable,
+                *arguments
+            ],
             stdin=subprocess.PIPE,
-            stdout=subprocess.PIPE
+            stdout=subprocess.DEVNULL
         )
         p.stdin.write(url_file.encode())
-        await p.stdin.drain()
         p.stdin.close()
 
-        if p.stdout:
-            is_dl_summary = False
-            log_buffer = ""
-            while True:
-                try:
-                    chunk = await p.stdout.readuntil(b"\r")
-                except asyncio.IncompleteReadError as e:
-                    chunk = e.partial
-                if not chunk:
+        while p.poll() is None:
+            global_stats = rpc(
+                caller=partial(rpc_session.post, url=rpc_uri),
+                secret=rpc_secret,
+                method="aria2.getGlobalStat"
+            )
+            if global_stats:
+                active = int(global_stats["numActive"])
+                waiting = int(global_stats["numWaiting"])
+                stopped = int(global_stats["numStopped"])
+                total = active + waiting + stopped
+                yield dict(
+                    total=total,
+                    completed=stopped,
+                    downloaded=f"{filesize.decimal(int(global_stats['downloadSpeed']))}/s"
+                )
+                if total == stopped:
+                    rpc(
+                        caller=partial(rpc_session.post, url=rpc_uri),
+                        secret=rpc_secret,
+                        method="aria2.shutdown"
+                    )
                     break
-                for line in chunk.decode().strip().splitlines():
-                    if not line:
-                        continue
-                    if line.startswith("Download Results"):
-                        # we know it's 100% downloaded, but let's use the avg dl speed value
-                        is_dl_summary = True
-                    elif line.startswith("[") and line.endswith("]"):
-                        if progress and "%" in line:
-                            # id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs
-                            # eta may not always be available
-                            data_parts = line[1:-1].split()
-                            perc_parts = data_parts[1].split("(")
-                            if len(perc_parts) == 2:
-                                # might otherwise be e.g., 0B/0B, with no % symbol provided
-                                progress(
-                                    total=100,
-                                    completed=int(perc_parts[1][:-2]),
-                                    downloaded=f"{data_parts[3].split(':')[1]}/s"
-                                )
-                    elif is_dl_summary and "OK" in line and "|" in line:
-                        gid, status, avg_speed, path_or_uri = line.split("|")
-                        progress(total=100, completed=100, downloaded=avg_speed.strip())
-                    elif not is_dl_summary:
-                        if "aria2 will resume download if the transfer is restarted" in line:
-                            continue
-                        if "If there are any errors, then see the log file" in line:
-                            continue
-                        log_buffer += f"{line.strip()}\n"
+            time.sleep(1)
 
-            if log_buffer and not silent:
-                # wrap to console width - padding - '[Aria2c]: '
-                log_buffer = "\n ".join(textwrap.wrap(
-                    log_buffer.rstrip(),
-                    width=console.width - 20,
-                    initial_indent=""
-                ))
-                console.log(Text.from_ansi("\n[Aria2c]: " + log_buffer))
-
-        await p.wait()
+        p.wait()
 
         if p.returncode != 0:
             raise subprocess.CalledProcessError(p.returncode, arguments)
@@ -188,7 +228,81 @@ async def aria2c(
             raise KeyboardInterrupt()
         raise
 
-    return p.returncode
+
+def aria2c(
+    urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
+    output_dir: Path,
+    filename: str,
+    headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
+    cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
+    proxy: Optional[str] = None,
+    max_workers: Optional[int] = None
+) -> Generator[dict[str, Any], None, None]:
+    """
+    Download files using Aria2(c).
+    https://aria2.github.io
+
+    Yields the following download status updates while chunks are downloading:
+
+    - {total: 100} (100% download total)
+    - {completed: 1} (1% download progress out of 100%)
+    - {downloaded: "10.1 MB/s"} (currently downloading at a rate of 10.1 MB/s)
+
+    The data is in the same format accepted by rich's progress.update() function.
+
+    Parameters:
+        urls: Web URL(s) to file(s) to download. You can use a dictionary with the key
+            "url" for the URI, and other keys for extra arguments to use per-URL.
+        output_dir: The folder to save the file into. If the save path's directory does
+            not exist then it will be made automatically.
+        filename: The filename or filename template to use for each file. The variables
+            you can use are `i` for the URL index and `ext` for the URL extension.
+        headers: A mapping of HTTP Header Key/Values to use for all downloads.
+        cookies: A mapping of Cookie Key/Values or a Cookie Jar to use for all downloads.
+        proxy: An optional proxy URI to route connections through for all downloads.
+        max_workers: The maximum number of concurrent downloads. Defaults to
+            min(32,(cpu_count+4)). Used for the --max-concurrent-downloads option.
+ """ + if proxy and not proxy.lower().startswith("http://"): + # Only HTTP proxies are supported by aria2(c) + proxy = urlparse(proxy) + + port = get_free_port() + username, password = get_random_bytes(8).hex(), get_random_bytes(8).hex() + local_proxy = f"http://{username}:{password}@localhost:{port}" + + scheme = { + "https": "http+ssl", + "socks5h": "socks" + }.get(proxy.scheme, proxy.scheme) + + remote_server = f"{scheme}://{proxy.hostname}" + if proxy.port: + remote_server += f":{proxy.port}" + if proxy.username or proxy.password: + remote_server += "#" + if proxy.username: + remote_server += proxy.username + if proxy.password: + remote_server += f":{proxy.password}" + + p = subprocess.Popen( + [ + "pproxy", + "-l", f"http://:{port}#{username}:{password}", + "-r", remote_server + ], + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL + ) + + try: + yield from download(urls, output_dir, filename, headers, cookies, local_proxy, max_workers) + finally: + p.kill() + p.wait() + return + yield from download(urls, output_dir, filename, headers, cookies, proxy, max_workers) __all__ = ("aria2c",) diff --git a/devine/core/utilities.py b/devine/core/utilities.py index e35a2b4..5232a7a 100644 --- a/devine/core/utilities.py +++ b/devine/core/utilities.py @@ -3,6 +3,7 @@ import contextlib import importlib.util import re import shutil +import socket import sys import time import unicodedata @@ -10,11 +11,9 @@ from collections import defaultdict from datetime import datetime from pathlib import Path from types import ModuleType -from typing import AsyncIterator, Optional, Sequence, Union -from urllib.parse import urlparse +from typing import Optional, Sequence, Union import chardet -import pproxy import requests from construct import ValidationError from langcodes import Language, closest_match @@ -244,35 +243,17 @@ def try_ensure_utf8(data: bytes) -> bytes: return data -@contextlib.asynccontextmanager -async def start_pproxy(proxy: str) -> AsyncIterator[str]: - proxy = urlparse(proxy) +def get_free_port() -> int: + """ + Get an available port to use between a-b (inclusive). - scheme = { - "https": "http+ssl", - "socks5h": "socks" - }.get(proxy.scheme, proxy.scheme) - - remote_server = f"{scheme}://{proxy.hostname}" - if proxy.port: - remote_server += f":{proxy.port}" - if proxy.username or proxy.password: - remote_server += "#" - if proxy.username: - remote_server += proxy.username - if proxy.password: - remote_server += f":{proxy.password}" - - server = pproxy.Server("http://localhost:0") # random port - remote = pproxy.Connection(remote_server) - handler = await server.start_server({"rserver": [remote]}) - - try: - port = handler.sockets[0].getsockname()[1] - yield f"http://localhost:{port}" - finally: - handler.close() - await handler.wait_closed() + The port is freed as soon as this has returned, therefore, it + is possible for the port to be taken before you try to use it. + """ + with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: + s.bind(("", 0)) + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + return s.getsockname()[1] class FPS(ast.NodeVisitor):