From 4e875f5ffca19da44da2d8c16b5e7d166c382bf3 Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Wed, 22 Feb 2023 01:06:03 +0000 Subject: [PATCH] Multi-thread the new DASH download system, improve redundancy Just like the commit for HLS multi-threading, this mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. The benefit of this is that we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs. --- devine/core/manifests/dash.py | 314 +++++++++++++++++----------------- 1 file changed, 156 insertions(+), 158 deletions(-) diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py index 5bb64ac..97fa555 100644 --- a/devine/core/manifests/dash.py +++ b/devine/core/manifests/dash.py @@ -6,9 +6,14 @@ import logging import math import re import sys +import time +import traceback +from concurrent import futures +from concurrent.futures import ThreadPoolExecutor from copy import copy from hashlib import md5 from pathlib import Path +from threading import Event from typing import Any, Callable, Optional, Union from urllib.parse import urljoin, urlparse from uuid import UUID @@ -303,7 +308,6 @@ class DASH: else: drm = None - segment_urls: list[str] = [] manifest = load_xml(session.get(manifest_url).text) manifest_url_query = urlparse(manifest_url).query @@ -312,107 +316,151 @@ class DASH: period_base_url = urljoin(manifest_url, period_base_url) period_duration = period.get("duration") or manifest.get("mediaPresentationDuration") + init_data: Optional[bytes] = None base_url = representation.findtext("BaseURL") or period_base_url segment_template = representation.find("SegmentTemplate") if segment_template is None: segment_template = adaptation_set.find("SegmentTemplate") - segment_base = representation.find("SegmentBase") - if segment_base is None: - segment_base = adaptation_set.find("SegmentBase") - 
segment_list = representation.find("SegmentList") if segment_list is None: segment_list = adaptation_set.find("SegmentList") - if segment_template is not None: - segment_template = copy(segment_template) - start_number = int(segment_template.get("startNumber") or 1) - segment_timeline = segment_template.find("SegmentTimeline") + if segment_template is None and segment_list is None and base_url: + # If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL + # Regardless which of the two is used, we can just directly grab the BaseURL + # Players would normally calculate segments via Byte-Ranges, but we don't care + track.url = urljoin(period_base_url, base_url) + track.descriptor = track.Descriptor.URL + track.drm = [drm] if drm else [] + else: + segments: list[tuple[str, Optional[str]]] = [] - for item in ("initialization", "media"): - value = segment_template.get(item) - if not value: - continue - if not re.match("^https?://", value, re.IGNORECASE): - if not base_url: - raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.") - value = urljoin(base_url, value) - if not urlparse(value).query and manifest_url_query: - value += f"?{manifest_url_query}" - segment_template.set(item, value) + if segment_template is not None: + segment_template = copy(segment_template) + start_number = int(segment_template.get("startNumber") or 1) + segment_timeline = segment_template.find("SegmentTimeline") - if segment_timeline is not None: - seg_time_list = [] - current_time = 0 - for s in segment_timeline.findall("S"): - if s.get("t"): - current_time = int(s.get("t")) - for _ in range(1 + (int(s.get("r") or 0))): - seg_time_list.append(current_time) - current_time += int(s.get("d")) - seg_num_list = list(range(start_number, len(seg_time_list) + start_number)) - segment_urls += [ - DASH.replace_fields( - segment_template.get("media"), + for item in ("initialization", "media"): + value = segment_template.get(item) + if not 
value: + continue + if not re.match("^https?://", value, re.IGNORECASE): + if not base_url: + raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.") + value = urljoin(base_url, value) + if not urlparse(value).query and manifest_url_query: + value += f"?{manifest_url_query}" + segment_template.set(item, value) + + init_url = segment_template.get("initialization") + if init_url: + res = session.get(DASH.replace_fields( + init_url, Bandwidth=representation.get("bandwidth"), - Number=n, - RepresentationID=representation.get("id"), - Time=t - ) - for t, n in zip(seg_time_list, seg_num_list) - ] + RepresentationID=representation.get("id") + )) + res.raise_for_status() + init_data = res.content + + if segment_timeline is not None: + seg_time_list = [] + current_time = 0 + for s in segment_timeline.findall("S"): + if s.get("t"): + current_time = int(s.get("t")) + for _ in range(1 + (int(s.get("r") or 0))): + seg_time_list.append(current_time) + current_time += int(s.get("d")) + seg_num_list = list(range(start_number, len(seg_time_list) + start_number)) + + for t, n in zip(seg_time_list, seg_num_list): + segments.append(( + DASH.replace_fields( + segment_template.get("media"), + Bandwidth=representation.get("bandwidth"), + Number=n, + RepresentationID=representation.get("id"), + Time=t + ), None + )) + else: + if not period_duration: + raise ValueError("Duration of the Period was unable to be determined.") + period_duration = DASH.pt_to_sec(period_duration) + segment_duration = float(segment_template.get("duration")) + segment_timescale = float(segment_template.get("timescale") or 1) + total_segments = math.ceil(period_duration / (segment_duration / segment_timescale)) + + for s in range(start_number, start_number + total_segments): + segments.append(( + DASH.replace_fields( + segment_template.get("media"), + Bandwidth=representation.get("bandwidth"), + Number=s, + RepresentationID=representation.get("id"), + Time=s + ), None + )) + elif 
segment_list is not None: + base_media_url = urljoin(period_base_url, base_url) + + init_data = None + initialization = segment_list.find("Initialization") + if initialization: + source_url = initialization.get("sourceURL") + if source_url is None: + source_url = base_media_url + + res = session.get(source_url) + res.raise_for_status() + init_data = res.content + + segment_urls = segment_list.findall("SegmentURL") + for segment_url in segment_urls: + media_url = segment_url.get("media") + if media_url is None: + media_url = base_media_url + + segments.append(( + media_url, + segment_url.get("mediaRange") + )) else: - if not period_duration: - raise ValueError("Duration of the Period was unable to be determined.") - period_duration = DASH.pt_to_sec(period_duration) - segment_duration = float(segment_template.get("duration")) - segment_timescale = float(segment_template.get("timescale") or 1) + log.error("Could not find a way to get segments from this MPD manifest.") + log.debug(manifest_url) + sys.exit(1) - total_segments = math.ceil(period_duration / (segment_duration / segment_timescale)) - segment_urls += [ - DASH.replace_fields( - segment_template.get("media"), - Bandwidth=representation.get("bandwidth"), - Number=s, - RepresentationID=representation.get("id"), - Time=s - ) - for s in range(start_number, start_number + total_segments) - ] + if not drm and isinstance(track, (Video, Audio)): + try: + drm = Widevine.from_init_data(init_data) + except Widevine.Exceptions.PSSHNotFound: + # it might not have Widevine DRM, or might not have found the PSSH + log.warning("No Widevine PSSH was found for this track, is it DRM free?") + else: + # license and grab content keys + if not license_widevine: + raise ValueError("license_widevine func must be supplied to use Widevine DRM") + license_widevine(drm) - init_data = None - init_url = segment_template.get("initialization") - if init_url: - res = session.get(DASH.replace_fields( - init_url, - 
Bandwidth=representation.get("bandwidth"), - RepresentationID=representation.get("id") - )) - res.raise_for_status() - init_data = res.content - if not drm: - try: - drm = Widevine.from_init_data(init_data) - except Widevine.Exceptions.PSSHNotFound: - # it might not have Widevine DRM, or might not have found the PSSH - log.warning("No Widevine PSSH was found for this track, is it DRM free?") - else: - # license and grab content keys - if not license_widevine: - raise ValueError("license_widevine func must be supplied to use Widevine DRM") - license_widevine(drm) + state_event = Event() - for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")): - segment_filename = str(i).zfill(len(str(len(segment_urls)))) - segment_save_path = (save_dir / segment_filename).with_suffix(".mp4") + def download_segment(filename: str, segment: tuple[str, Optional[str]]): + time.sleep(0.1) + if state_event.is_set(): + return + + segment_save_path = (save_dir / filename).with_suffix(".mp4") + + segment_uri, segment_range = segment asyncio.run(aria2c( - segment_url, + segment_uri, segment_save_path, session.headers, - proxy + proxy, + byte_range=segment_range )) if isinstance(track, Audio) or init_data: @@ -438,84 +486,34 @@ class DASH: track.drm = None if callable(track.OnDecrypted): track.OnDecrypted(track) - elif segment_list is not None: - base_media_url = urljoin(period_base_url, base_url) - if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")): - # at least one segment has no URL specified, it uses the base url and ranges - track.url = base_media_url - track.descriptor = track.Descriptor.URL - track.drm = [drm] if drm else [] - else: - init_data = None - initialization = segment_list.find("Initialization") - if initialization: - source_url = initialization.get("sourceURL") - if source_url is None: - source_url = base_media_url - res = session.get(source_url) - res.raise_for_status() - init_data = res.content - if not drm: - try: - drm = 
Widevine.from_init_data(init_data) - except Widevine.Exceptions.PSSHNotFound: - # it might not have Widevine DRM, or might not have found the PSSH - log.warning("No Widevine PSSH was found for this track, is it DRM free?") - else: - # license and grab content keys - if not license_widevine: - raise ValueError("license_widevine func must be supplied to use Widevine DRM") - license_widevine(drm) - - for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")): - segment_filename = str(i).zfill(len(str(len(segment_urls)))) - segment_save_path = (save_dir / segment_filename).with_suffix(".mp4") - - media_url = segment_url.get("media") - if media_url is None: - media_url = base_media_url - - asyncio.run(aria2c( - media_url, - segment_save_path, - session.headers, - proxy, - byte_range=segment_url.get("mediaRange") - )) - - if isinstance(track, Audio) or init_data: - with open(segment_save_path, "rb+") as f: - segment_data = f.read() - if isinstance(track, Audio): - # fix audio decryption on ATVP by fixing the sample description index - # TODO: Is this in mpeg data, or init data? 
- segment_data = re.sub( - b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02", - b"\\g<1>\x01", - segment_data - ) - # prepend the init data to be able to decrypt - if init_data: - f.seek(0) - f.write(init_data) - f.write(segment_data) - - if drm: - # TODO: What if the manifest does not mention DRM, but has DRM - drm.decrypt(segment_save_path) - track.drm = None - if callable(track.OnDecrypted): - track.OnDecrypted(track) - elif segment_base is not None or base_url: - # SegmentBase more or less boils down to defined ByteRanges - # So, we don't care, just download the full file - track.url = urljoin(period_base_url, base_url) - track.descriptor = track.Descriptor.URL - track.drm = [drm] if drm else [] - else: - log.error("Could not find a way to get segments from this MPD manifest.") - sys.exit(1) + with tqdm(total=len(segments), unit="segments") as pbar: + with ThreadPoolExecutor(max_workers=16) as pool: + try: + for download in futures.as_completed(( + pool.submit( + download_segment, + filename=str(i).zfill(len(str(len(segments)))), + segment=segment + ) + for i, segment in enumerate(segments) + )): + if download.cancelled(): + continue + e = download.exception() + if e: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + traceback.print_exception(e) + log.error(f"Segment Download worker threw an unhandled exception: {e!r}") + sys.exit(1) + else: + pbar.update(1) + except KeyboardInterrupt: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + log.info("Received Keyboard Interrupt, stopping...") + return @staticmethod def get_language(*options: Any) -> Optional[Language]: