From a1ed083b74227903f3519706223180b06aefd2dd Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Thu, 15 Feb 2024 11:13:14 +0000 Subject: [PATCH] Add support for the new Downloaders to DASH --- devine/core/manifests/dash.py | 152 +++++++--------------------------- 1 file changed, 30 insertions(+), 122 deletions(-) diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py index f0d3680..af09857 100644 --- a/devine/core/manifests/dash.py +++ b/devine/core/manifests/dash.py @@ -6,13 +6,10 @@ import logging import math import re import sys -import time -from concurrent import futures -from concurrent.futures import ThreadPoolExecutor from copy import copy from functools import partial from pathlib import Path -from typing import Any, Callable, MutableMapping, Optional, Union +from typing import Any, Callable, Optional, Union from urllib.parse import urljoin, urlparse from uuid import UUID from zlib import crc32 @@ -23,8 +20,6 @@ from lxml.etree import Element from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.pssh import PSSH from requests import Session -from requests.cookies import RequestsCookieJar -from rich import filesize from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack from devine.core.downloaders import downloader @@ -435,57 +430,36 @@ class DASH: progress(total=len(segments)) - download_sizes = [] - download_speed_window = 5 - last_speed_refresh = time.time() + downloader_ = downloader + if downloader.__name__ == "aria2c" and any(bytes_range is not None for url, bytes_range in segments): + # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader + downloader_ = requests_downloader - with ThreadPoolExecutor(max_workers=16) as pool: - for i, download in enumerate(futures.as_completed(( - pool.submit( - DASH.download_segment, - url=url, - out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"), - track=track, - proxy=proxy, - headers=session.headers, - cookies=session.cookies, - bytes_range=bytes_range - ) - for n, (url, bytes_range) in enumerate(segments) - ))): - try: - download_size = download.result() - except KeyboardInterrupt: - DOWNLOAD_CANCELLED.set() # skip pending track downloads - progress(downloaded="[yellow]CANCELLING") - pool.shutdown(wait=True, cancel_futures=True) - progress(downloaded="[yellow]CANCELLED") - # tell dl that it was cancelled - # the pool is already shut down, so exiting loop is fine - raise - except Exception: - DOWNLOAD_CANCELLED.set() # skip pending track downloads - progress(downloaded="[red]FAILING") - pool.shutdown(wait=True, cancel_futures=True) - progress(downloaded="[red]FAILED") - # tell dl that it failed - # the pool is already shut down, so exiting loop is fine - raise - else: - progress(advance=1) - - now = time.time() - time_since = now - last_speed_refresh - - if download_size: # no size == skipped dl - download_sizes.append(download_size) - - if download_sizes and (time_since > download_speed_window or i == len(segments)): - data_size = sum(download_sizes) - download_speed = data_size / (time_since or 1) - progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s") - last_speed_refresh = now - download_sizes.clear() + for status_update in downloader_( + urls=[ + { + "url": url, + "headers": { + "Range": f"bytes={bytes_range}" + } + } + for url, bytes_range in segments + ], + output_dir=save_dir, + filename="{i:0%d}.mp4" % (len(str(len(segments)))), + headers=session.headers, + cookies=session.cookies, + proxy=proxy, + max_workers=16 + ): + file_downloaded = status_update.get("file_downloaded") + if file_downloaded and callable(track.OnSegmentDownloaded): + track.OnSegmentDownloaded(file_downloaded) + else: + downloaded = status_update.get("downloaded") + if downloaded and downloaded.endswith("/s"): + status_update["downloaded"] = f"DASH {downloaded}" + progress(**status_update) with open(save_path, "wb") as f: if init_data: @@ -518,72 +492,6 @@ class DASH: progress(downloaded="Downloaded") - @staticmethod - def download_segment( - url: str, - out_path: Path, - track: AnyTrack, - proxy: Optional[str] = None, - headers: Optional[MutableMapping[str, str | bytes]] = None, - cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None, - bytes_range: Optional[str] = None - ) -> int: - """ - Download a DASH Media Segment. - - Parameters: - url: Full HTTP(S) URL to the Segment you want to download. - out_path: Path to save the downloaded Segment file to. - track: The Track object of which this Segment is for. Currently only used to - fix an invalid value in the TFHD box of Audio Tracks. - proxy: Proxy URI to use when downloading the Segment file. - headers: HTTP Headers to send when requesting the Segment file. - cookies: Cookies to send when requesting the Segment file. The actual cookies sent - will be resolved based on the URI among other parameters. Multiple cookies with - the same name but a different domain/path are resolved. - bytes_range: Download only specific bytes of the Segment file using the Range header. - - Returns the file size of the downloaded Segment in bytes. - """ - if DOWNLOAD_CANCELLED.is_set(): - raise KeyboardInterrupt() - - if bytes_range: - # aria2(c) doesn't support byte ranges, use python-requests - downloader_ = requests_downloader - headers_ = dict(**headers, Range=f"bytes={bytes_range}") - else: - downloader_ = downloader - headers_ = headers - - downloader_( - uri=url, - out=out_path, - headers=headers_, - cookies=cookies, - proxy=proxy, - segmented=True - ) - - if callable(track.OnSegmentDownloaded): - track.OnSegmentDownloaded(out_path) - - # fix audio decryption on ATVP by fixing the sample description index - # TODO: Should this be done in the video data or the init data? - if isinstance(track, Audio): - with open(out_path, "rb+") as f: - segment_data = f.read() - fixed_segment_data = re.sub( - b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02", - b"\\g<1>\x01", - segment_data - ) - if fixed_segment_data != segment_data: - f.seek(0) - f.write(fixed_segment_data) - - return out_path.stat().st_size - @staticmethod def _get( item: str,