fix(requests): Fix multithreaded downloads

For some reason, moving the download speed calculation code from the requests() function into the download() function makes the downloads actually run in parallel threads instead of sequentially.
rlaphoenix 2024-04-05 00:37:44 +01:00
parent 5d1b54b8fa
commit 994ab152a4
1 changed file with 75 additions and 78 deletions


@@ -1,7 +1,7 @@
 import math
 import os
 import time
-from concurrent import futures
+from concurrent.futures import as_completed
 from concurrent.futures.thread import ThreadPoolExecutor
 from http.cookiejar import CookieJar
 from pathlib import Path
@@ -19,11 +19,14 @@ RETRY_WAIT = 2
 CHUNK_SIZE = 1024
 PROGRESS_WINDOW = 5
 
+DOWNLOAD_SIZES = []
+LAST_SPEED_REFRESH = time.time()
+
 
 def download(
     url: str,
     save_path: Path,
     session: Optional[Session] = None,
+    segmented: bool = False,
     **kwargs: Any
 ) -> Generator[dict[str, Any], None, None]:
     """
@@ -48,10 +51,13 @@ def download(
         session: The Requests Session to make HTTP requests with. Useful to set Header,
             Cookie, and Proxy data. Connections are saved and re-used with the session
             so long as the server keeps the connection alive.
+        segmented: If downloads are segments or parts of one bigger file.
         kwargs: Any extra keyword arguments to pass to the session.get() call. Use this
             for one-time request changes like a header, cookie, or proxy. For example,
             to request Byte-ranges use e.g., `headers={"Range": "bytes=0-128"}`.
     """
 
+    global LAST_SPEED_REFRESH
+
     session = session or Session()
     save_dir = save_path.parent
@@ -69,6 +75,7 @@ def download(
                 file_downloaded=save_path,
                 written=save_path.stat().st_size
             )
 
+        # TODO: This should return, potential recovery bug
         # TODO: Design a control file format so we know how much of the file is missing
         control_file.write_bytes(b"")
@@ -77,6 +84,8 @@ def download(
     try:
         while True:
             written = 0
+
+            # these are for single-url speed calcs only
             download_sizes = []
             last_speed_refresh = time.time()
 
@@ -84,16 +93,17 @@ def download(
                 stream = session.get(url, stream=True, **kwargs)
                 stream.raise_for_status()
 
-                try:
-                    content_length = int(stream.headers.get("Content-Length", "0"))
-                except ValueError:
-                    content_length = 0
-
-                if content_length > 0:
-                    yield dict(total=math.ceil(content_length / CHUNK_SIZE))
-                else:
-                    # we have no data to calculate total chunks
-                    yield dict(total=None)  # indeterminate mode
+                if not segmented:
+                    try:
+                        content_length = int(stream.headers.get("Content-Length", "0"))
+                    except ValueError:
+                        content_length = 0
+
+                    if content_length > 0:
+                        yield dict(total=math.ceil(content_length / CHUNK_SIZE))
+                    else:
+                        # we have no data to calculate total chunks
+                        yield dict(total=None)  # indeterminate mode
 
                 with open(save_path, "wb") as f:
                     for chunk in stream.iter_content(chunk_size=CHUNK_SIZE):
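The if not segmented: gate added above exists because requests() sizes the shared progress bar in segments (yield dict(total=len(urls)), later in this diff), so an individual segment must not overwrite that total with its own chunk count. For a whole-file download, the Content-Length handling reduces to this sketch (chunk_total is a hypothetical helper name):

    import math
    from typing import Optional

    CHUNK_SIZE = 1024

    def chunk_total(content_length_header: Optional[str]) -> Optional[int]:
        # turn a Content-Length header into a chunk count; None = indeterminate
        try:
            content_length = int(content_length_header or "0")
        except ValueError:
            content_length = 0  # e.g. a malformed header value
        if content_length > 0:
            return math.ceil(content_length / CHUNK_SIZE)
        return None  # we have no data to calculate total chunks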
@@ -101,23 +111,32 @@ def download(
                         f.write(chunk)
                         written += download_size
-                        yield dict(advance=1)
-
-                        now = time.time()
-                        time_since = now - last_speed_refresh
-
-                        download_sizes.append(download_size)
-                        if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
-                            data_size = sum(download_sizes)
-                            download_speed = math.ceil(data_size / (time_since or 1))
-                            yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
-                            last_speed_refresh = now
-                            download_sizes.clear()
-
-                yield dict(
-                    file_downloaded=save_path,
-                    written=written
-                )
-
+                        if not segmented:
+                            yield dict(advance=1)
+                            now = time.time()
+                            time_since = now - last_speed_refresh
+                            download_sizes.append(download_size)
+                            if time_since > PROGRESS_WINDOW or download_size < CHUNK_SIZE:
+                                data_size = sum(download_sizes)
+                                download_speed = math.ceil(data_size / (time_since or 1))
+                                yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
+                                last_speed_refresh = now
+                                download_sizes.clear()
+
+                yield dict(file_downloaded=save_path, written=written)
+
+                if segmented:
+                    yield dict(advance=1)
+                    now = time.time()
+                    time_since = now - LAST_SPEED_REFRESH
+                    if written:  # no size == skipped dl
+                        DOWNLOAD_SIZES.append(written)
+                    if DOWNLOAD_SIZES and time_since > PROGRESS_WINDOW:
+                        data_size = sum(DOWNLOAD_SIZES)
+                        download_speed = math.ceil(data_size / (time_since or 1))
+                        yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
+                        LAST_SPEED_REFRESH = now
+                        DOWNLOAD_SIZES.clear()
+
                 break
             except Exception as e:
                 save_path.unlink(missing_ok=True)
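download() reports progress by yielding plain dicts (total, advance, downloaded, file_downloaded/written). A sketch of a consumer driving the generator, assuming download() is imported from this module; the handling below is illustrative, not devine's actual progress-bar code:

    from pathlib import Path

    for status in download("https://example.com/file.bin", Path("file.bin")):
        if "total" in status:
            print("expected chunks:", status["total"])  # None = indeterminate
        if status.get("advance"):
            pass  # tick a progress bar forward
        if "downloaded" in status:
            print("speed:", status["downloaded"])  # e.g. "1.2 MB/s"
        if status.get("file_downloaded"):
            print("saved", status["file_downloaded"], "-", status["written"], "bytes")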
@@ -237,59 +256,37 @@ def requests(
     yield dict(total=len(urls))
 
-    download_sizes = []
-    last_speed_refresh = time.time()
-
-    with ThreadPoolExecutor(max_workers=max_workers) as pool:
-        for i, future in enumerate(futures.as_completed((
-            pool.submit(
-                download,
-                session=session,
-                **url
-            )
-            for url in urls
-        ))):
-            file_path, download_size = None, None
-            try:
-                for status_update in future.result():
-                    if status_update.get("file_downloaded") and status_update.get("written"):
-                        file_path = status_update["file_downloaded"]
-                        download_size = status_update["written"]
-                    elif len(urls) == 1:
-                        # these are per-chunk updates, only useful if it's one big file
-                        yield status_update
-            except KeyboardInterrupt:
-                DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                yield dict(downloaded="[yellow]CANCELLING")
-                pool.shutdown(wait=True, cancel_futures=True)
-                yield dict(downloaded="[yellow]CANCELLED")
-                # tell dl that it was cancelled
-                # the pool is already shut down, so exiting loop is fine
-                raise
-            except Exception:
-                DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                yield dict(downloaded="[red]FAILING")
-                pool.shutdown(wait=True, cancel_futures=True)
-                yield dict(downloaded="[red]FAILED")
-                # tell dl that it failed
-                # the pool is already shut down, so exiting loop is fine
-                raise
-            else:
-                yield dict(file_downloaded=file_path, written=download_size)
-                yield dict(advance=1)
-
-                now = time.time()
-                time_since = now - last_speed_refresh
-
-                if download_size:  # no size == skipped dl
-                    download_sizes.append(download_size)
-                if download_sizes and (time_since > PROGRESS_WINDOW or i == len(urls)):
-                    data_size = sum(download_sizes)
-                    download_speed = math.ceil(data_size / (time_since or 1))
-                    yield dict(downloaded=f"{filesize.decimal(download_speed)}/s")
-                    last_speed_refresh = now
-                    download_sizes.clear()
+    try:
+        with ThreadPoolExecutor(max_workers=max_workers) as pool:
+            for future in as_completed(
+                pool.submit(
+                    download,
+                    session=session,
+                    segmented=True,
+                    **url
+                )
+                for url in urls
+            ):
+                try:
+                    yield from future.result()
+                except KeyboardInterrupt:
+                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
+                    yield dict(downloaded="[yellow]CANCELLING")
+                    pool.shutdown(wait=True, cancel_futures=True)
+                    yield dict(downloaded="[yellow]CANCELLED")
+                    # tell dl that it was cancelled
+                    # the pool is already shut down, so exiting loop is fine
+                    raise
+                except Exception:
+                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
+                    yield dict(downloaded="[red]FAILING")
+                    pool.shutdown(wait=True, cancel_futures=True)
+                    yield dict(downloaded="[red]FAILED")
+                    # tell dl that it failed
+                    # the pool is already shut down, so exiting loop is fine
+                    raise
+    finally:
+        DOWNLOAD_SIZES.clear()
 
 
 __all__ = ("requests",)
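For reference, the fan-out shape requests() now uses, submit one job per URL and consume results in completion order, reduces to this self-contained sketch (fetch is a stand-in worker, not devine's download()):

    from concurrent.futures import as_completed
    from concurrent.futures.thread import ThreadPoolExecutor

    def fetch(url: str) -> list[dict]:
        # stand-in worker: pretend the URL was downloaded, report status dicts
        return [dict(file_downloaded=url, written=1024), dict(advance=1)]

    def fan_out(urls: list[str], max_workers: int = 4):
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            for future in as_completed(
                pool.submit(fetch, url)
                for url in urls
            ):
                # re-yield each worker's status updates as they complete
                yield from future.result()

    for status in fan_out(["https://example.com/a", "https://example.com/b"]):
        print(status)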