From 994ab152a40f9c53881bc9bf2280156a7193825c Mon Sep 17 00:00:00 2001
From: rlaphoenix
Date: Fri, 5 Apr 2024 00:37:44 +0100
Subject: [PATCH] fix(requests): Fix multithreaded downloads

For some reason, moving the download speed calculation code from the
requests() function into the download() function makes the downloads
actually run multi-threaded instead of sequentially.
---
 devine/core/downloaders/requests.py | 153 ++++++++++++++--------
 1 file changed, 75 insertions(+), 78 deletions(-)

diff --git a/devine/core/downloaders/requests.py b/devine/core/downloaders/requests.py
index c6d23fc..3e47c18 100644
--- a/devine/core/downloaders/requests.py
+++ b/devine/core/downloaders/requests.py
@@ -1,7 +1,7 @@
 import math
 import os
 import time
-from concurrent import futures
+from concurrent.futures import as_completed
 from concurrent.futures.thread import ThreadPoolExecutor
 from http.cookiejar import CookieJar
 from pathlib import Path
@@ -19,11 +19,14 @@
 RETRY_WAIT = 2
 CHUNK_SIZE = 1024
 PROGRESS_WINDOW = 5
+DOWNLOAD_SIZES = []
+LAST_SPEED_REFRESH = time.time()

 def download(
     url: str,
     save_path: Path,
     session: Optional[Session] = None,
+    segmented: bool = False,
     **kwargs: Any
 ) -> Generator[dict[str, Any], None, None]:
     """
@@ -48,10 +51,13 @@ def download(
         session: The Requests Session to make HTTP requests with. Useful to set Header,
             Cookie, and Proxy data. Connections are saved and re-used with the session
             so long as the server keeps the connection alive.
+        segmented: If downloads are segments or parts of one bigger file.
         kwargs: Any extra keyword arguments to pass to the session.get() call. Use this
             for one-time request changes like a header, cookie, or proxy. For example,
             to request Byte-ranges use e.g., `headers={"Range": "bytes=0-128"}`.
     """
+    global LAST_SPEED_REFRESH
+
     session = session or Session()

     save_dir = save_path.parent
@@ -69,6 +75,7 @@ def download(
                 file_downloaded=save_path,
                 written=save_path.stat().st_size
             )
+            # TODO: This should return, potential recovery bug
         # TODO: Design a control file format so we know how much of the file is missing
         control_file.write_bytes(b"")

@@ -77,6 +84,8 @@ def download(
     try:
         while True:
             written = 0
+
+            # these are for single-url speed calcs only
             download_sizes = []
             last_speed_refresh = time.time()

@@ -84,16 +93,17 @@ def download(
                 stream = session.get(url, stream=True, **kwargs)
                 stream.raise_for_status()

-                try:
-                    content_length = int(stream.headers.get("Content-Length", "0"))
-                except ValueError:
-                    content_length = 0
+                if not segmented:
+                    try:
+                        content_length = int(stream.headers.get("Content-Length", "0"))
+                    except ValueError:
+                        content_length = 0

-                if content_length > 0:
-                    yield dict(total=math.ceil(content_length / CHUNK_SIZE))
-                else:
-                    # we have no data to calculate total chunks
-                    yield dict(total=None)  # indeterminate mode
+                    if content_length > 0:
+                        yield dict(total=math.ceil(content_length / CHUNK_SIZE))
+                    else:
+                        # we have no data to calculate total chunks
+                        yield dict(total=None)  # indeterminate mode

                 with open(save_path, "wb") as f:
                     for chunk in stream.iter_content(chunk_size=CHUNK_SIZE):
@@ -101,23 +111,32 @@ def download(
                         download_size = len(chunk)
                         f.write(chunk)
                         written += download_size
-                        yield dict(advance=1)
+                        if not segmented:
+                            yield dict(advance=1)
+                            now = time.time()
+                            time_since = now - last_speed_refresh
+                            download_sizes.append(download_size)
+                            if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
+                                data_size = sum(download_sizes)
+                                download_speed = math.ceil(data_size / (time_since or 1))
+                                yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
+                            last_speed_refresh = now
+                            download_sizes.clear()

-                now = time.time()
-                time_since = now - last_speed_refresh
+                yield dict(file_downloaded=save_path, written=written)

-                download_sizes.append(download_size)
-                if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
-                    data_size = sum(download_sizes)
-                    download_speed = math.ceil(data_size / (time_since or 1))
-                    yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
-                    last_speed_refresh = now
-                    download_sizes.clear()
-
-            yield dict(
-                file_downloaded=save_path,
-                written=written
-            )
+                if segmented:
+                    yield dict(advance=1)
+                    now = time.time()
+                    time_since = now - LAST_SPEED_REFRESH
+                    if written:  # no size == skipped dl
+                        DOWNLOAD_SIZES.append(written)
+                    if DOWNLOAD_SIZES and time_since > PROGRESS_WINDOW:
+                        data_size = sum(DOWNLOAD_SIZES)
+                        download_speed = math.ceil(data_size / (time_since or 1))
+                        yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
+                        LAST_SPEED_REFRESH = now
+                        DOWNLOAD_SIZES.clear()
             break
         except Exception as e:
             save_path.unlink(missing_ok=True)
@@ -237,59 +256,37 @@ def requests(

     yield dict(total=len(urls))

-    download_sizes = []
-    last_speed_refresh = time.time()
-
-    with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        for i, future in enumerate(futures.as_completed((
-            pool.submit(
-                download,
-                session=session,
-                **url
-            )
-            for url in urls
-        ))):
-            file_path, download_size = None, None
-            try:
-                for status_update in future.result():
-                    if status_update.get("file_downloaded") and status_update.get("written"):
-                        file_path = status_update["file_downloaded"]
-                        download_size = status_update["written"]
-                    elif len(urls) == 1:
-                        # these are per-chunk updates, only useful if it's one big file
-                        yield status_update
-            except KeyboardInterrupt:
-                DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                yield dict(downloaded="[yellow]CANCELLING")
-                pool.shutdown(wait=True, cancel_futures=True)
-                yield dict(downloaded="[yellow]CANCELLED")
-                # tell dl that it was cancelled
-                # the pool is already shut down, so exiting loop is fine
-                raise
-            except Exception:
-                DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                yield dict(downloaded="[red]FAILING")
-                pool.shutdown(wait=True, cancel_futures=True)
-                yield dict(downloaded="[red]FAILED")
-                # tell dl that it failed
-                # the pool is already shut down, so exiting loop is fine
-                raise
-            else:
-                yield dict(file_downloaded=file_path, written=download_size)
-                yield dict(advance=1)
-
-                now = time.time()
-                time_since = now - last_speed_refresh
-
-                if download_size:  # no size == skipped dl
-                    download_sizes.append(download_size)
-
-                if download_sizes and (time_since > PROGRESS_WINDOW or i == len(urls)):
-                    data_size = sum(download_sizes)
-                    download_speed = math.ceil(data_size / (time_since or 1))
-                    yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
-                    last_speed_refresh = now
-                    download_sizes.clear()
+    try:
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            for future in as_completed(
+                pool.submit(
+                    download,
+                    session=session,
+                    segmented=True,
+                    **url
+                )
+                for url in urls
+            ):
+                try:
+                    yield from future.result()
+                except KeyboardInterrupt:
+                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
+                    yield dict(downloaded="[yellow]CANCELLING")
+                    pool.shutdown(wait=True, cancel_futures=True)
+                    yield dict(downloaded="[yellow]CANCELLED")
+                    # tell dl that it was cancelled
+                    # the pool is already shut down, so exiting loop is fine
+                    raise
+                except Exception:
+                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
+                    yield dict(downloaded="[red]FAILING")
+                    pool.shutdown(wait=True, cancel_futures=True)
+                    yield dict(downloaded="[red]FAILED")
+                    # tell dl that it failed
+                    # the pool is already shut down, so exiting loop is fine
+                    raise
+    finally:
+        DOWNLOAD_SIZES.clear()


 __all__ = ("requests",)
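
The core of the change in requests() is the submit-then-drain pattern: every URL is submitted to the pool up front, and each finished future's progress events are re-yielded with `yield from` as the futures complete. A minimal standalone sketch of that pattern, with a stub generator standing in for devine's download() (the stub and URLs are illustrative, not project code):

    from concurrent.futures import as_completed
    from concurrent.futures.thread import ThreadPoolExecutor

    def stub_download(url):
        # stand-in for download(); the real generator yields progress dicts
        yield dict(file_downloaded=url, written=1024)
        yield dict(advance=1)

    urls = [f"https://example.com/seg{i}.ts" for i in range(4)]

    with ThreadPoolExecutor(max_workers=2) as pool:
        for future in as_completed(pool.submit(stub_download, url) for url in urls):
            for event in future.result():
                print(event)

Note that pool.submit() on a generator function returns almost immediately: the worker only constructs the generator object, and the generator's body runs wherever it is iterated.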
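Both download() and requests() speak the same progress protocol: each yielded dict carries total (None for indeterminate), advance, downloaded (a pre-formatted speed string), or file_downloaded plus written. A sketch of a consumer, assuming a plain print handler rather than devine's actual Rich progress wiring:

    from typing import Any, Generator

    def consume(events: Generator[dict[str, Any], None, None]) -> None:
        # illustrative only; devine feeds these events into a progress bar
        for event in events:
            if "total" in event:
                print("total steps:", event["total"])  # None == indeterminate
            elif "advance" in event:
                print("step complete")
            elif "downloaded" in event:
                print("speed:", event["downloaded"])  # e.g. "1.2 MB/s"
            elif "file_downloaded" in event:
                print("saved:", event["file_downloaded"], event["written"], "bytes")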
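The speed figure itself is a rolling average: byte counts accumulate until PROGRESS_WINDOW seconds have elapsed, then their sum is divided by the elapsed time. The same arithmetic as a hypothetical standalone helper (not part of devine):

    import math
    import time
    from typing import Optional

    PROGRESS_WINDOW = 5  # seconds, matching the patch

    class SpeedWindow:
        """Rolling download-speed window mirroring the patch's calculation."""

        def __init__(self) -> None:
            self.sizes: list[int] = []
            self.last_refresh = time.time()

        def add(self, size: int) -> Optional[int]:
            # record bytes; return bytes/sec once the window has elapsed
            self.sizes.append(size)
            time_since = time.time() - self.last_refresh
            if time_since > PROGRESS_WINDOW:
                speed = math.ceil(sum(self.sizes) / (time_since or 1))
                self.last_refresh = time.time()
                self.sizes.clear()
                return speed
            return None

Keeping one module-level window (DOWNLOAD_SIZES and LAST_SPEED_REFRESH) means every worker thread feeds a single shared speed read-out, which is why the segmented branch uses the globals where the single-URL branch uses per-call locals.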