Rework the Aria2c Downloader

- Downloads are now multithreaded directly in the downloader.
- Now reuses connections instead of having to close and reopen connections for every single download.
- Progress updates are now yielded back to the caller instead of drilling down a progress callable.
- Instead of parsing download progress information in a very hacky way from the stdout stream, use aria2's RPC interface.
- Added a new utility get_free_port which is needed to choose aria2's RPC port as I do not want to use the default port in case the user is already using this port for another tool or reason. Also, to try mitigate port scanning attacks that target aria2 RPC ports.
- The config entry `aria2c.max_concurrent_downloads` is now actually used by aria2c when downloading.
- The `--max-concurrent-downloads` option and config value now defaults to `min(32,(cpu_count+4))` (usually around 16 for above average systems) instead of 5.
- Automated pproxy proxy rerouter is made via subprocess instead of trying to re-do what the pproxy entry point does for us, less code, less trouble, and was ultimately easier to implement.
This commit is contained in:
rlaphoenix 2024-02-15 16:07:42 +00:00
parent 2b7fc929f6
commit 630a9906ce
4 changed files with 225 additions and 133 deletions

View File

@ -11,13 +11,12 @@ which does not keep comments.
## aria2c (dict)
- `max_concurrent_downloads`
Maximum number of parallel downloads. Default: `5`
Note: Currently unused as downloads are multi-threaded by Devine rather than Aria2c.
Devine internally has a constant set value of 16 for it's parallel downloads.
Maximum number of parallel downloads. Default: `min(32,(cpu_count+4))`
Note: Overrides the `max_workers` parameter of the aria2(c) downloader function.
- `max_connection_per_server`
Maximum number of connections to one server for each download. Default: `1`
- `split`
Split a file into N chunks and download each chunk on it's own connection. Default: `5`
Split a file into N chunks and download each chunk on its own connection. Default: `5`
- `file_allocation`
Specify file allocation method. Default: `"prealloc"`

View File

@ -1,12 +1,10 @@
import asyncio
from ..config import config
from .aria2c import aria2c
from .curl_impersonate import curl_impersonate
from .requests import requests
downloader = {
"aria2c": lambda *args, **kwargs: asyncio.run(aria2c(*args, **kwargs)),
"aria2c": aria2c,
"curl_impersonate": curl_impersonate,
"requests": requests
}[config.downloader]

View File

@ -1,84 +1,144 @@
import asyncio
import os
import subprocess
import textwrap
import time
from functools import partial
from http.cookiejar import CookieJar
from pathlib import Path
from typing import MutableMapping, Optional, Union
from typing import Any, Callable, Generator, MutableMapping, Optional, Union
from urllib.parse import urlparse
import requests
from Crypto.Random import get_random_bytes
from requests import Session
from requests.cookies import RequestsCookieJar, cookiejar_from_dict, get_cookie_header
from rich import filesize
from rich.text import Text
from devine.core.config import config
from devine.core.console import console
from devine.core.utilities import get_binary_path, start_pproxy
from devine.core.utilities import get_binary_path, get_free_port
async def aria2c(
uri: Union[str, list[str]],
out: Path,
headers: Optional[dict] = None,
def rpc(caller: Callable, secret: str, method: str, *params: Any) -> dict[str, Any]:
"""Make a call to Aria2's JSON-RPC API."""
rpc_res = caller(
json={
"jsonrpc": "2.0",
"id": get_random_bytes(16).hex(),
"method": method,
"params": [f"token:{secret}", *params]
}
).json()
if rpc_res.get("code"):
# wrap to console width - padding - '[Aria2c]: '
error_pretty = "\n ".join(textwrap.wrap(
f"RPC Error: {rpc_res['message']} ({rpc_res['code']})".strip(),
width=console.width - 20,
initial_indent=""
))
console.log(Text.from_ansi("\n[Aria2c]: " + error_pretty))
return rpc_res["result"]
def download(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
filename: str,
headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
proxy: Optional[str] = None,
silent: bool = False,
segmented: bool = False,
progress: Optional[partial] = None,
*args: str
) -> int:
"""
Download files using Aria2(c).
https://aria2.github.io
max_workers: Optional[int] = None
) -> Generator[dict[str, Any], None, None]:
if not urls:
raise ValueError("urls must be provided and not empty")
elif not isinstance(urls, (str, dict, list)):
raise TypeError(f"Expected urls to be {str} or {dict} or a list of one of them, not {type(urls)}")
If multiple URLs are provided they will be downloaded in the provided order
to the output directory. They will not be merged together.
"""
if not isinstance(uri, list):
uri = [uri]
if not output_dir:
raise ValueError("output_dir must be provided")
elif not isinstance(output_dir, Path):
raise TypeError(f"Expected output_dir to be {Path}, not {type(output_dir)}")
if cookies and not isinstance(cookies, CookieJar):
cookies = cookiejar_from_dict(cookies)
if not filename:
raise ValueError("filename must be provided")
elif not isinstance(filename, str):
raise TypeError(f"Expected filename to be {str}, not {type(filename)}")
if not isinstance(headers, (MutableMapping, type(None))):
raise TypeError(f"Expected headers to be {MutableMapping}, not {type(headers)}")
if not isinstance(cookies, (MutableMapping, RequestsCookieJar, type(None))):
raise TypeError(f"Expected cookies to be {MutableMapping} or {RequestsCookieJar}, not {type(cookies)}")
if not isinstance(proxy, (str, type(None))):
raise TypeError(f"Expected proxy to be {str}, not {type(proxy)}")
if not max_workers:
max_workers = min(32, (os.cpu_count() or 1) + 4)
elif not isinstance(max_workers, int):
raise TypeError(f"Expected max_workers to be {int}, not {type(max_workers)}")
if not isinstance(urls, list):
urls = [urls]
executable = get_binary_path("aria2c", "aria2")
if not executable:
raise EnvironmentError("Aria2c executable not found...")
if proxy and proxy.lower().split(":")[0] != "http":
# HTTPS proxies are not supported by aria2(c).
# Proxy the proxy via pproxy to access it as an HTTP proxy.
async with start_pproxy(proxy) as pproxy_:
return await aria2c(uri, out, headers, cookies, pproxy_, silent, segmented, progress, *args)
if proxy and not proxy.lower().startswith("http://"):
raise ValueError("Only HTTP proxies are supported by aria2(c)")
if cookies and not isinstance(cookies, CookieJar):
cookies = cookiejar_from_dict(cookies)
multiple_urls = len(uri) > 1
url_files = []
for i, url in enumerate(uri):
url_text = url
if multiple_urls:
url_text += f"\n\tdir={out}"
url_text += f"\n\tout={i:08}.mp4"
for i, url in enumerate(urls):
if isinstance(url, str):
url_data = {
"url": url
}
else:
url_text += f"\n\tdir={out.parent}"
url_text += f"\n\tout={out.name}"
url_data: dict[str, Any] = url
url_filename = filename.format(
i=i,
ext=Path(url_data["url"]).suffix
)
url_text = url_data["url"]
url_text += f"\n\tdir={output_dir}"
url_text += f"\n\tout={url_filename}"
if cookies:
mock_request = requests.Request(url=url)
mock_request = requests.Request(url=url_data["url"])
cookie_header = get_cookie_header(cookies, mock_request)
if cookie_header:
url_text += f"\n\theader=Cookie: {cookie_header}"
for key, value in url_data.items():
if key == "url":
continue
if key == "headers":
for header_name, header_value in value.items():
url_text += f"\n\theader={header_name}: {header_value}"
else:
url_text += f"\n\t{key}={value}"
url_files.append(url_text)
url_file = "\n".join(url_files)
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", 5))
rpc_port = get_free_port()
rpc_secret = get_random_bytes(16).hex()
rpc_uri = f"http://127.0.0.1:{rpc_port}/jsonrpc"
rpc_session = Session()
max_concurrent_downloads = int(config.aria2c.get("max_concurrent_downloads", max_workers))
max_connection_per_server = int(config.aria2c.get("max_connection_per_server", 1))
split = int(config.aria2c.get("split", 5))
file_allocation = config.aria2c.get("file_allocation", "prealloc")
if segmented:
if len(urls) > 1:
split = 1
file_allocation = "none"
arguments = [
# [Basic Options]
"--input-file", "-",
"--out", out.name,
"--all-proxy", proxy or "",
"--continue=true",
# [Connection Options]
@ -92,11 +152,13 @@ async def aria2c(
"--allow-overwrite=true",
"--auto-file-renaming=false",
"--console-log-level=warn",
f"--download-result={'default' if progress else 'hide'}",
"--download-result=default",
f"--file-allocation={file_allocation}",
"--summary-interval=0",
# [Extra Options]
*args
# [RPC Options]
"--enable-rpc=true",
f"--rpc-listen-port={rpc_port}",
f"--rpc-secret={rpc_secret}"
]
for header, value in (headers or {}).items():
@ -115,66 +177,44 @@ async def aria2c(
arguments.extend(["--header", f"{header}: {value}"])
try:
p = await asyncio.create_subprocess_exec(
p = subprocess.Popen(
[
executable,
*arguments,
*arguments
],
stdin=subprocess.PIPE,
stdout=subprocess.PIPE
stdout=subprocess.DEVNULL
)
p.stdin.write(url_file.encode())
await p.stdin.drain()
p.stdin.close()
if p.stdout:
is_dl_summary = False
log_buffer = ""
while True:
try:
chunk = await p.stdout.readuntil(b"\r")
except asyncio.IncompleteReadError as e:
chunk = e.partial
if not chunk:
break
for line in chunk.decode().strip().splitlines():
if not line:
continue
if line.startswith("Download Results"):
# we know it's 100% downloaded, but let's use the avg dl speed value
is_dl_summary = True
elif line.startswith("[") and line.endswith("]"):
if progress and "%" in line:
# id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs
# eta may not always be available
data_parts = line[1:-1].split()
perc_parts = data_parts[1].split("(")
if len(perc_parts) == 2:
# might otherwise be e.g., 0B/0B, with no % symbol provided
progress(
total=100,
completed=int(perc_parts[1][:-2]),
downloaded=f"{data_parts[3].split(':')[1]}/s"
while p.poll() is None:
global_stats = rpc(
caller=partial(rpc_session.post, url=rpc_uri),
secret=rpc_secret,
method="aria2.getGlobalStat"
)
elif is_dl_summary and "OK" in line and "|" in line:
gid, status, avg_speed, path_or_uri = line.split("|")
progress(total=100, completed=100, downloaded=avg_speed.strip())
elif not is_dl_summary:
if "aria2 will resume download if the transfer is restarted" in line:
continue
if "If there are any errors, then see the log file" in line:
continue
log_buffer += f"{line.strip()}\n"
if global_stats:
active = int(global_stats["numActive"])
waiting = int(global_stats["numWaiting"])
stopped = int(global_stats["numStopped"])
total = active + waiting + stopped
yield dict(
total=total,
completed=stopped,
downloaded=f"{filesize.decimal(int(global_stats['downloadSpeed']))}/s"
)
if total == stopped:
rpc(
caller=partial(rpc_session.post, url=rpc_uri),
secret=rpc_secret,
method="aria2.shutdown"
)
break
time.sleep(1)
if log_buffer and not silent:
# wrap to console width - padding - '[Aria2c]: '
log_buffer = "\n ".join(textwrap.wrap(
log_buffer.rstrip(),
width=console.width - 20,
initial_indent=""
))
console.log(Text.from_ansi("\n[Aria2c]: " + log_buffer))
await p.wait()
p.wait()
if p.returncode != 0:
raise subprocess.CalledProcessError(p.returncode, arguments)
@ -188,7 +228,81 @@ async def aria2c(
raise KeyboardInterrupt()
raise
return p.returncode
def aria2c(
urls: Union[str, list[str], dict[str, Any], list[dict[str, Any]]],
output_dir: Path,
filename: str,
headers: Optional[MutableMapping[str, Union[str, bytes]]] = None,
cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
proxy: Optional[str] = None,
max_workers: Optional[int] = None
) -> Generator[dict[str, Any], None, None]:
"""
Download files using Aria2(c).
https://aria2.github.io
Yields the following download status updates while chunks are downloading:
- {total: 100} (100% download total)
- {completed: 1} (1% download progress out of 100%)
- {downloaded: "10.1 MB/s"} (currently downloading at a rate of 10.1 MB/s)
The data is in the same format accepted by rich's progress.update() function.
Parameters:
urls: Web URL(s) to file(s) to download. You can use a dictionary with the key
"url" for the URI, and other keys for extra arguments to use per-URL.
output_dir: The folder to save the file into. If the save path's directory does
not exist then it will be made automatically.
filename: The filename or filename template to use for each file. The variables
you can use are `i` for the URL index and `ext` for the URL extension.
headers: A mapping of HTTP Header Key/Values to use for all downloads.
cookies: A mapping of Cookie Key/Values or a Cookie Jar to use for all downloads.
proxy: An optional proxy URI to route connections through for all downloads.
max_workers: The maximum amount of threads to use for downloads. Defaults to
min(32,(cpu_count+4)). Use for the --max-concurrent-downloads option.
"""
if proxy and not proxy.lower().startswith("http://"):
# Only HTTP proxies are supported by aria2(c)
proxy = urlparse(proxy)
port = get_free_port()
username, password = get_random_bytes(8).hex(), get_random_bytes(8).hex()
local_proxy = f"http://{username}:{password}@localhost:{port}"
scheme = {
"https": "http+ssl",
"socks5h": "socks"
}.get(proxy.scheme, proxy.scheme)
remote_server = f"{scheme}://{proxy.hostname}"
if proxy.port:
remote_server += f":{proxy.port}"
if proxy.username or proxy.password:
remote_server += "#"
if proxy.username:
remote_server += proxy.username
if proxy.password:
remote_server += f":{proxy.password}"
p = subprocess.Popen(
[
"pproxy",
"-l", f"http://:{port}#{username}:{password}",
"-r", remote_server
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL
)
try:
yield from download(urls, output_dir, filename, headers, cookies, local_proxy, max_workers)
finally:
p.kill()
p.wait()
return
yield from download(urls, output_dir, filename, headers, cookies, proxy, max_workers)
__all__ = ("aria2c",)

View File

@ -3,6 +3,7 @@ import contextlib
import importlib.util
import re
import shutil
import socket
import sys
import time
import unicodedata
@ -10,11 +11,9 @@ from collections import defaultdict
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import AsyncIterator, Optional, Sequence, Union
from urllib.parse import urlparse
from typing import Optional, Sequence, Union
import chardet
import pproxy
import requests
from construct import ValidationError
from langcodes import Language, closest_match
@ -244,35 +243,17 @@ def try_ensure_utf8(data: bytes) -> bytes:
return data
@contextlib.asynccontextmanager
async def start_pproxy(proxy: str) -> AsyncIterator[str]:
proxy = urlparse(proxy)
def get_free_port() -> int:
"""
Get an available port to use between a-b (inclusive).
scheme = {
"https": "http+ssl",
"socks5h": "socks"
}.get(proxy.scheme, proxy.scheme)
remote_server = f"{scheme}://{proxy.hostname}"
if proxy.port:
remote_server += f":{proxy.port}"
if proxy.username or proxy.password:
remote_server += "#"
if proxy.username:
remote_server += proxy.username
if proxy.password:
remote_server += f":{proxy.password}"
server = pproxy.Server("http://localhost:0") # random port
remote = pproxy.Connection(remote_server)
handler = await server.start_server({"rserver": [remote]})
try:
port = handler.sockets[0].getsockname()[1]
yield f"http://localhost:{port}"
finally:
handler.close()
await handler.wait_closed()
The port is freed as soon as this has returned, therefore, it
is possible for the port to be taken before you try to use it.
"""
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]
class FPS(ast.NodeVisitor):