fix(requests): Fix multithreaded downloads

For some reason moving the download speed calculation code from the requests() function to the download() function makes it actually multi-threaded instead of sequential downloads.
This commit is contained in:
rlaphoenix 2024-04-05 00:37:44 +01:00
parent 5d1b54b8fa
commit 994ab152a4
1 changed files with 75 additions and 78 deletions

View File

@ -1,7 +1,7 @@
import math import math
import os import os
import time import time
from concurrent import futures from concurrent.futures import as_completed
from concurrent.futures.thread import ThreadPoolExecutor from concurrent.futures.thread import ThreadPoolExecutor
from http.cookiejar import CookieJar from http.cookiejar import CookieJar
from pathlib import Path from pathlib import Path
@ -19,11 +19,14 @@ RETRY_WAIT = 2
CHUNK_SIZE = 1024 CHUNK_SIZE = 1024
PROGRESS_WINDOW = 5 PROGRESS_WINDOW = 5
DOWNLOAD_SIZES = []
LAST_SPEED_REFRESH = time.time()
def download( def download(
url: str, url: str,
save_path: Path, save_path: Path,
session: Optional[Session] = None, session: Optional[Session] = None,
segmented: bool = False,
**kwargs: Any **kwargs: Any
) -> Generator[dict[str, Any], None, None]: ) -> Generator[dict[str, Any], None, None]:
""" """
@ -48,10 +51,13 @@ def download(
session: The Requests Session to make HTTP requests with. Useful to set Header, session: The Requests Session to make HTTP requests with. Useful to set Header,
Cookie, and Proxy data. Connections are saved and re-used with the session Cookie, and Proxy data. Connections are saved and re-used with the session
so long as the server keeps the connection alive. so long as the server keeps the connection alive.
segmented: If downloads are segments or parts of one bigger file.
kwargs: Any extra keyword arguments to pass to the session.get() call. Use this kwargs: Any extra keyword arguments to pass to the session.get() call. Use this
for one-time request changes like a header, cookie, or proxy. For example, for one-time request changes like a header, cookie, or proxy. For example,
to request Byte-ranges use e.g., `headers={"Range": "bytes=0-128"}`. to request Byte-ranges use e.g., `headers={"Range": "bytes=0-128"}`.
""" """
global LAST_SPEED_REFRESH
session = session or Session() session = session or Session()
save_dir = save_path.parent save_dir = save_path.parent
@ -69,6 +75,7 @@ def download(
file_downloaded=save_path, file_downloaded=save_path,
written=save_path.stat().st_size written=save_path.stat().st_size
) )
# TODO: This should return, potential recovery bug
# TODO: Design a control file format so we know how much of the file is missing # TODO: Design a control file format so we know how much of the file is missing
control_file.write_bytes(b"") control_file.write_bytes(b"")
@ -77,6 +84,8 @@ def download(
try: try:
while True: while True:
written = 0 written = 0
# these are for single-url speed calcs only
download_sizes = [] download_sizes = []
last_speed_refresh = time.time() last_speed_refresh = time.time()
@ -84,6 +93,7 @@ def download(
stream = session.get(url, stream=True, **kwargs) stream = session.get(url, stream=True, **kwargs)
stream.raise_for_status() stream.raise_for_status()
if not segmented:
try: try:
content_length = int(stream.headers.get("Content-Length", "0")) content_length = int(stream.headers.get("Content-Length", "0"))
except ValueError: except ValueError:
@ -101,11 +111,10 @@ def download(
f.write(chunk) f.write(chunk)
written += download_size written += download_size
if not segmented:
yield dict(advance=1) yield dict(advance=1)
now = time.time() now = time.time()
time_since = now - last_speed_refresh time_since = now - last_speed_refresh
download_sizes.append(download_size) download_sizes.append(download_size)
if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE: if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
data_size = sum(download_sizes) data_size = sum(download_sizes)
@ -114,10 +123,20 @@ def download(
last_speed_refresh = now last_speed_refresh = now
download_sizes.clear() download_sizes.clear()
yield dict( yield dict(file_downloaded=save_path, written=written)
file_downloaded=save_path,
written=written if segmented:
) yield dict(advance=1)
now = time.time()
time_since = now - LAST_SPEED_REFRESH
if written: # no size == skipped dl
DOWNLOAD_SIZES.append(written)
if DOWNLOAD_SIZES and time_since > PROGRESS_WINDOW:
data_size = sum(DOWNLOAD_SIZES)
download_speed = math.ceil(data_size / (time_since or 1))
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
LAST_SPEED_REFRESH = now
DOWNLOAD_SIZES.clear()
break break
except Exception as e: except Exception as e:
save_path.unlink(missing_ok=True) save_path.unlink(missing_ok=True)
@ -237,27 +256,19 @@ def requests(
yield dict(total=len(urls)) yield dict(total=len(urls))
download_sizes = [] try:
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=max_workers) as pool: with ThreadPoolExecutor(max_workers=max_workers) as pool:
for i, future in enumerate(futures.as_completed(( for future in as_completed(
pool.submit( pool.submit(
download, download,
session=session, session=session,
segmented=True,
**url **url
) )
for url in urls for url in urls
))): ):
file_path, download_size = None, None
try: try:
for status_update in future.result(): yield from future.result()
if status_update.get("file_downloaded") and status_update.get("written"):
file_path = status_update["file_downloaded"]
download_size = status_update["written"]
elif len(urls) == 1:
# these are per-chunk updates, only useful if it's one big file
yield status_update
except KeyboardInterrupt: except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads DOWNLOAD_CANCELLED.set() # skip pending track downloads
yield dict(downloaded="[yellow]CANCELLING") yield dict(downloaded="[yellow]CANCELLING")
@ -274,22 +285,8 @@ def requests(
# tell dl that it failed # tell dl that it failed
# the pool is already shut down, so exiting loop is fine # the pool is already shut down, so exiting loop is fine
raise raise
else: finally:
yield dict(file_downloaded=file_path, written=download_size) DOWNLOAD_SIZES.clear()
yield dict(advance=1)
now = time.time()
time_since = now - last_speed_refresh
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_sizes and (time_since > PROGRESS_WINDOW or i == len(urls)):
data_size = sum(download_sizes)
download_speed = math.ceil(data_size / (time_since or 1))
yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
__all__ = ("requests",) __all__ = ("requests",)