Rework the Aria2c Downloader

- Downloads are now multithreaded directly in the downloader.
- Now reuses connections instead of having to close and reopen connections for every single download.
- Progress updates are now yielded back to the caller instead of drilling down a progress callable.
- Instead of parsing download progress information in a very hacky way from the stdout stream, use aria2's RPC interface.
- Added a new utility get_free_port which is needed to choose aria2's RPC port as I do not want to use the default port in case the user is already using this port for another tool or reason. Also, to try mitigate port scanning attacks that target aria2 RPC ports.
- The config entry `aria2c.max_concurrent_downloads` is now actually used by aria2c when downloading.
- The `--max-concurrent-downloads` option and config value now defaults to `min(32,(cpu_count+4))` (usually around 16 for above average systems) instead of 5.
- Automated pproxy proxy rerouter is made via subprocess instead of trying to re-do what the pproxy entry point does for us, less code, less trouble, and was ultimately easier to implement.
This commit is contained in:
rlaphoenix 2024-02-15 16:07:42 +00:00
parent 2b7fc929f6
commit 630a9906ce
4 changed files with 225 additions and 133 deletions

View File

@ -11,13 +11,12 @@ which does not keep comments.
## aria2c (dict) ## aria2c (dict)
- `max_concurrent_downloads` - `max_concurrent_downloads`
Maximum number of parallel downloads. Default: `5` Maximum number of parallel downloads. Default: `min(32,(cpu_count+4))`
Note: Currently unused as downloads are multi-threaded by Devine rather than Aria2c. Note: Overrides the `max_workers` parameter of the aria2(c) downloader function.
Devine internally has a constant set value of 16 for it's parallel downloads.
- `max_connection_per_server` - `max_connection_per_server`
Maximum number of connections to one server for each download. Default: `1` Maximum number of connections to one server for each download. Default: `1`
- `split` - `split`
Split a file into N chunks and download each chunk on it's own connection. Default: `5` Split a file into N chunks and download each chunk on its own connection. Default: `5`
- `file_allocation` - `file_allocation`
Specify file allocation method. Default: `"prealloc"` Specify file allocation method. Default: `"prealloc"`

View File

@ -1,12 +1,10 @@
import asyncio
from ..config import config from ..config import config
from .aria2c import aria2c from .aria2c import aria2c
from .curl_impersonate import curl_impersonate from .curl_impersonate import curl_impersonate
from .requests import requests from .requests import requests
downloader = { downloader = {
"aria2c": lambda *args, **kwargs: asyncio.run(aria2c(*args, **kwargs)), "aria2c": aria2c,
"curl_impersonate": curl_impersonate, "curl_impersonate": curl_impersonate,
"requests": requests "requests": requests
}[config.downloader] }[config.downloader]

View File

@ -1,84 +1,144 @@
import asyncio import os
import subprocess import subprocess
import textwrap import textwrap
import time
from functools import partial from functools import partial
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
from typing import MutableMapping, Optional, Union from typing import Any, Callable, Generator, MutableMapping, Optional, Union
from urllib.parse import urlparse
import requests import requests
from Crypto.Random import get_random_bytes
from requests import Session
from requests.cookies import RequestsCookieJar, cookiejar_from_dict, get_cookie_header from requests.cookies import RequestsCookieJar, cookiejar_from_dict, get_cookie_header
from rich import filesize
from rich.text import Text from rich.text import Text
from devine.core.config import config from devine.core.config import config
from devine.core.console import console from devine.core.console import console
from devine.core.utilities import get_binary_path, start_pproxy from devine.core.utilities import get_binary_path, get_free_port
async def aria2c( def rpc(caller: Callable, secret: str, method: str, *params: Any) -> dict[str, Any]:
uri: Union[str, list[str]], """Make a call to Aria2's JSON-RPC API."""
out: Path, rpc_res = caller(
headers: Optional[dict] = None, json={
"jsonrpc": "2.0",
"id": get_random_bytes(16).hex(),
"method": method,
"params": [f"token:{secret}", *params]
}
).json()
if rpc_res.get("code"):
# wrap to console width - padding - '[Aria2c]: '
error_pretty = "\n ".join(textwrap.wrap(
f"RPC Error: {rpc_res['message']} ({rpc_res['code']})".strip(),
width=console.width - 20,
initial_indent=""
))
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
return rpc_res["result"]
def download(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
filename: str,
headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None, cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
silent: bool = False, max_workers: Optional[int] = None
segmented: bool = False, ) -> Generator[dict[str, Any], None, None]:
progress: Optional[partial] = None, if not urls:
*args: str raise ValueError("urls must be provided and not empty")
) -> int: elif not isinstance(urls, (str, dict, list)):
""" raise TypeError(f"Expected urls to be {str} or {dict} or a list of one of them, not {type(urls)}")
Download files using Aria2(c).
https://aria2.github.io
If multiple URLs are provided they will be downloaded in the provided order if not output_dir:
to the output directory. They will not be merged together. raise ValueError("output_dir must be provided")
""" elif not isinstance(output_dir, Path):
if not isinstance(uri, list): raise TypeError(f"Expected output_dir to be {Path}, not {type(output_dir)}")
uri = [uri]
if cookies and not isinstance(cookies, CookieJar): if not filename:
cookies = cookiejar_from_dict(cookies) raise ValueError("filename must be provided")
elif not isinstance(filename, str):
raise TypeError(f"Expected filename to be {str}, not {type(filename)}")
if not isinstance(headers, (MutableMapping, type(None))):
raise TypeError(f"Expected headers to be {MutableMapping}, not {type(headers)}")
if not isinstance(cookies, (MutableMapping, RequestsCookieJar, type(None))):
raise TypeError(f"Expected cookies to be {MutableMapping} or {RequestsCookieJar}, not {type(cookies)}")
if not isinstance(proxy, (str, type(None))):
raise TypeError(f"Expected proxy to be {str}, not {type(proxy)}")
if not max_workers:
max_workers = min(32, (os.cpu_count() or 1) + 4)
elif not isinstance(max_workers, int):
raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")
if not isinstance(urls, list):
urls = [urls]
executable = get_binary_path("aria2c", "aria2") executable = get_binary_path("aria2c", "aria2")
if not executable: if not executable:
raise EnvironmentError("Aria2c executable not found...") raise EnvironmentError("Aria2c executable not found...")
if proxy and proxy.lower().split(":")[0] != "http": if proxy and not proxy.lower().startswith("http://"):
# HTTPS proxies are not supported by aria2(c). raise ValueError("Only HTTP proxies are supported by aria2(c)")
# Proxy the proxy via pproxy to access it as an HTTP proxy.
async with start_pproxy(proxy) as pproxy_: if cookies and not isinstance(cookies, CookieJar):
return await aria2c(uri, out, headers, cookies, pproxy_, silent, segmented, progress, *args) cookies = cookiejar_from_dict(cookies)
multiple_urls = len(uri) > 1
url_files = [] url_files = []
for i, url in enumerate(uri): for i, url in enumerate(urls):
url_text = url if isinstance(url, str):
if multiple_urls: url_data = {
url_text += f"\n\tdir={out}" "url": url
url_text += f"\n\tout={i:08}.mp4" }
else: else:
url_text += f"\n\tdir={out.parent}" url_data: dict[str, Any] = url
url_text += f"\n\tout={out.name}" url_filename = filename.format(
i=i,
ext=Path(url_data["url"]).suffix
)
url_text = url_data["url"]
url_text += f"\n\tdir={output_dir}"
url_text += f"\n\tout={url_filename}"
if cookies: if cookies:
mock_request = requests.Request(url=url) mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request) cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header: if cookie_header:
url_text += f"\n\theader=Cookie: {cookie_header}" url_text += f"\n\theader=Cookie: {cookie_header}"
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
url_text += f"\n\theader={header_name}: {header_value}"
else:
url_text += f"\n\t{key}={value}"
url_files.append(url_text) url_files.append(url_text)
url_file = "\n".join(url_files) url_file = "\n".join(url_files)
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", 5)) rpc_port = get_free_port()
rpc_secret = get_random_bytes(16).hex()
rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
rpc_session = Session()
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1)) max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
split = int(config.aria2c.get("split", 5)) split = int(config.aria2c.get("split", 5))
file_allocation = config.aria2c.get("file_allocation", "prealloc") file_allocation = config.aria2c.get("file_allocation", "prealloc")
if segmented: if len(urls) > 1:
split = 1 split = 1
file_allocation = "none" file_allocation = "none"
arguments = [ arguments = [
# [Basic Options] # [Basic Options]
"--input-file", "-", "--input-file", "-",
"--out", out.name,
"--all-proxy", proxy or "", "--all-proxy", proxy or "",
"--continue=true", "--continue=true",
# [Connection Options] # [Connection Options]
@ -92,11 +152,13 @@ async def aria2c(
"--allow-overwrite=true", "--allow-overwrite=true",
"--auto-file-renaming=false", "--auto-file-renaming=false",
"--console-log-level=warn", "--console-log-level=warn",
f"--download-result={'default' if progress else 'hide'}", "--download-result=default",
f"--file-allocation={file_allocation}", f"--file-allocation={file_allocation}",
"--summary-interval=0", "--summary-interval=0",
# [Extra Options] # [RPC Options]
*args "--enable-rpc=true",
f"--rpc-listen-port={rpc_port}",
f"--rpc-secret={rpc_secret}"
] ]
for header, value in (headers or {}).items(): for header, value in (headers or {}).items():
@ -115,66 +177,44 @@ async def aria2c(
arguments.extend(["--header", f"{header}: {value}"]) arguments.extend(["--header", f"{header}: {value}"])
try: try:
p = await asyncio.create_subprocess_exec( p = subprocess.Popen(
executable, [
*arguments, executable,
*arguments
],
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stdout=subprocess.PIPE stdout=subprocess.DEVNULL
) )
p.stdin.write(url_file.encode()) p.stdin.write(url_file.encode())
await p.stdin.drain()
p.stdin.close() p.stdin.close()
if p.stdout: while p.poll() is None:
is_dl_summary = False global_stats = rpc(
log_buffer = "" caller=partial(rpc_session.post, url=rpc_uri),
while True: secret=rpc_secret,
try: method="aria2.getGlobalStat"
chunk = await p.stdout.readuntil(b"\r") )
except asyncio.IncompleteReadError as e: if global_stats:
chunk = e.partial active = int(global_stats["numActive"])
if not chunk: waiting = int(global_stats["numWaiting"])
stopped = int(global_stats["numStopped"])
total = active + waiting + stopped
yield dict(
total=total,
completed=stopped,
downloaded=f"{filesize.decimal(int(global_stats['downloadSpeed']))}/s"
)
if total == stopped:
rpc(
caller=partial(rpc_session.post, url=rpc_uri),
secret=rpc_secret,
method="aria2.shutdown"
)
break break
for line in chunk.decode().strip().splitlines(): time.sleep(1)
if not line:
continue
if line.startswith("Download Results"):
# we know it's 100% downloaded, but let's use the avg dl speed value
is_dl_summary = True
elif line.startswith("[") and line.endswith("]"):
if progress and "%" in line:
# id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs
# eta may not always be available
data_parts = line[1:-1].split()
perc_parts = data_parts[1].split("(")
if len(perc_parts) == 2:
# might otherwise be e.g., 0B/0B, with no % symbol provided
progress(
total=100,
completed=int(perc_parts[1][:-2]),
downloaded=f"{data_parts[3].split(':')[1]}/s"
)
elif is_dl_summary and "OK" in line and "|" in line:
gid, status, avg_speed, path_or_uri = line.split("|")
progress(total=100, completed=100, downloaded=avg_speed.strip())
elif not is_dl_summary:
if "aria2 will resume download if the transfer is restarted" in line:
continue
if "If there are any errors, then see the log file" in line:
continue
log_buffer += f"{line.strip()}\n"
if log_buffer and not silent: p.wait()
# wrap to console width - padding - '[Aria2c]: '
log_buffer = "\n ".join(textwrap.wrap(
log_buffer.rstrip(),
width=console.width - 20,
initial_indent=""
))
console.log(Text.from_ansi("\n[Aria2c]: " + log_buffer))
await p.wait()
if p.returncode != 0: if p.returncode != 0:
raise subprocess.CalledProcessError(p.returncode, arguments) raise subprocess.CalledProcessError(p.returncode, arguments)
@ -188,7 +228,81 @@ async def aria2c(
raise KeyboardInterrupt() raise KeyboardInterrupt()
raise raise
return p.returncode
def aria2c(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
filename: str,
headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
proxy: Optional[str] = None,
max_workers: Optional[int] = None
) -> Generator[dict[str, Any], None, None]:
"""
Download files using Aria2(c).
https://aria2.github.io
Yields the following download status updates while chunks are downloading:
- {total: 100} (100% download total)
- {completed: 1} (1% download progress out of 100%)
- {downloaded: "10.1 MB/s"} (currently downloading at a rate of 10.1 MB/s)
The data is in the same format accepted by rich's progress.update() function.
Parameters:
urls: Web URL(s) to file(s) to download. You can use a dictionary with the key
"url" for the URI, and other keys for extra arguments to use per-URL.
output_dir: The folder to save the file into. If the save path's directory does
not exist then it will be made automatically.
filename: The filename or filename template to use for each file. The variables
you can use are `i` for the URL index and `ext` for the URL extension.
headers: A mapping of HTTP Header Key/Values to use for all downloads.
cookies: A mapping of Cookie Key/Values or a Cookie Jar to use for all downloads.
proxy: An optional proxy URI to route connections through for all downloads.
max_workers: The maximum amount of threads to use for downloads. Defaults to
min(32,(cpu_count+4)). Use for the --max-concurrent-downloads option.
"""
if proxy and not proxy.lower().startswith("http://"):
# Only HTTP proxies are supported by aria2(c)
proxy = urlparse(proxy)
port = get_free_port()
username, password = get_random_bytes(8).hex(), get_random_bytes(8).hex()
local_proxy = f"http://{username}:{password}@localhost:{port}"
scheme = {
"https": "http+ssl",
"socks5h": "socks"
}.get(proxy.scheme, proxy.scheme)
remote_server = f"{scheme}://{proxy.hostname}"
if proxy.port:
remote_server += f":{proxy.port}"
if proxy.username or proxy.password:
remote_server += "#"
if proxy.username:
remote_server += proxy.username
if proxy.password:
remote_server += f":{proxy.password}"
p = subprocess.Popen(
[
"pproxy",
"-l", f"http://:{port}#{username}:{password}",
"-r", remote_server
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL
)
try:
yield from download(urls, output_dir, filename, headers, cookies, local_proxy, max_workers)
finally:
p.kill()
p.wait()
return
yield from download(urls, output_dir, filename, headers, cookies, proxy, max_workers)
__all__ = ("aria2c",) __all__ = ("aria2c",)

View File

@ -3,6 +3,7 @@ import contextlib
import importlib.util import importlib.util
import re import re
import shutil import shutil
import socket
import sys import sys
import time import time
import unicodedata import unicodedata
@ -10,11 +11,9 @@ from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from types import ModuleType from types import ModuleType
from typing import AsyncIterator, Optional, Sequence, Union from typing import Optional, Sequence, Union
from urllib.parse import urlparse
import chardet import chardet
import pproxy
import requests import requests
from construct import ValidationError from construct import ValidationError
from langcodes import Language, closest_match from langcodes import Language, closest_match
@ -244,35 +243,17 @@ def try_ensure_utf8(data: bytes) -> bytes:
return data return data
@contextlib.asynccontextmanager def get_free_port() -> int:
async def start_pproxy(proxy: str) -> AsyncIterator[str]: """
proxy = urlparse(proxy) Get an available port to use between a-b (inclusive).
scheme = { The port is freed as soon as this has returned, therefore, it
"https": "http+ssl", is possible for the port to be taken before you try to use it.
"socks5h": "socks" """
}.get(proxy.scheme, proxy.scheme) with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
remote_server = f"{scheme}://{proxy.hostname}" s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if proxy.port: return s.getsockname()[1]
remote_server += f":{proxy.port}"
if proxy.username or proxy.password:
remote_server += "#"
if proxy.username:
remote_server += proxy.username
if proxy.password:
remote_server += f":{proxy.password}"
server = pproxy.Server("http://localhost:0") # random port
remote = pproxy.Connection(remote_server)
handler = await server.start_server({"rserver": [remote]})
try:
port = handler.sockets[0].getsockname()[1]
yield f"http://localhost:{port}"
finally:
handler.close()
await handler.wait_closed()
class FPS(ast.NodeVisitor): class FPS(ast.NodeVisitor):