From 9e6f5b25f39d0f56e5099793bf4717389436a3fc Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Tue, 21 Feb 2023 16:16:12 +0000 Subject: [PATCH] Multi-thread the new HLS download system This mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs. --- devine/core/manifests/hls.py | 85 ++++++++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 22 deletions(-) diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py index a518de7..71391b7 100644 --- a/devine/core/manifests/hls.py +++ b/devine/core/manifests/hls.py @@ -4,8 +4,14 @@ import asyncio import logging import re import sys +import time +import traceback +from concurrent import futures +from concurrent.futures import ThreadPoolExecutor from hashlib import md5 from pathlib import Path +from queue import Queue +from threading import Event from typing import Any, Callable, Optional, Union import m3u8 @@ -205,21 +211,17 @@ class HLS: log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.") sys.exit(1) - init_data = None - last_segment_key: tuple[Optional[Union[ClearKey, Widevine]], Optional[m3u8.Key]] = (None, None) + state_event = Event() - for i, segment in enumerate(tqdm(master.segments, unit="segments")): - segment_filename = str(i).zfill(len(str(len(master.segments)))) - segment_save_path = (save_dir / segment_filename).with_suffix(".mp4") + def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue): + time.sleep(0.1) + if state_event.is_set(): + return - if segment.key and last_segment_key[1] != segment.key: - # try: - # drm = HLS.get_drm([segment.key]) - # except NotImplementedError: - # drm = None # never mind, try with master.keys - # if not drm and master.keys: - # # TODO: segment might have multiple keys but m3u8 only grabs the last! - # drm = HLS.get_drm(master.keys) + segment_save_path = (save_dir / filename).with_suffix(".mp4") + + newest_segment_key = segment_key.get() + if segment.key and newest_segment_key[1] != segment.key: try: drm = HLS.get_drm( # TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY @@ -242,12 +244,14 @@ class HLS: if not license_widevine: raise ValueError("license_widevine func must be supplied to use Widevine DRM") license_widevine(drm) - last_segment_key = (drm, segment.key) + newest_segment_key = (drm, segment.key) + segment_key.put(newest_segment_key) if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment): - continue + return - if segment.init_section and (not init_data or segment.discontinuity): + newest_init_data = init_data.get() + if segment.init_section and (not newest_init_data or segment.discontinuity): # Only use the init data if there's no init data yet (e.g., start of file) # or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP. # Even if a new EXT-X-MAP is supplied, it may just be duplicate and would @@ -258,7 +262,8 @@ class HLS: log.debug("Got new init segment, %s", segment.init_section.uri) res = session.get(segment.init_section.uri) res.raise_for_status() - init_data = res.content + newest_init_data = res.content + init_data.put(newest_init_data) if not segment.uri.startswith(segment.base_uri): segment.uri = segment.base_uri + segment.uri @@ -270,7 +275,7 @@ class HLS: proxy )) - if isinstance(track, Audio) or init_data: + if isinstance(track, Audio) or newest_init_data: with open(segment_save_path, "rb+") as f: segment_data = f.read() if isinstance(track, Audio): @@ -282,17 +287,53 @@ class HLS: segment_data ) # prepend the init data to be able to decrypt - if init_data: + if newest_init_data: f.seek(0) - f.write(init_data) + f.write(newest_init_data) f.write(segment_data) - if last_segment_key[0]: - last_segment_key[0].decrypt(segment_save_path) + if newest_segment_key[0]: + newest_segment_key[0].decrypt(segment_save_path) track.drm = None if callable(track.OnDecrypted): track.OnDecrypted(track) + init_data = Queue(maxsize=1) + segment_key = Queue(maxsize=1) + # otherwise will be stuck waiting on the first pool, forever + init_data.put(None) + segment_key.put((None, None)) + + with tqdm(total=len(master.segments), unit="segments") as pbar: + with ThreadPoolExecutor(max_workers=16) as pool: + try: + for download in futures.as_completed(( + pool.submit( + download_segment, + filename=str(i).zfill(len(str(len(master.segments)))), + segment=segment, + init_data=init_data, + segment_key=segment_key + ) + for i, segment in enumerate(master.segments) + )): + if download.cancelled(): + continue + e = download.exception() + if e: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + traceback.print_exception(e) + log.error(f"Segment Download worker threw an unhandled exception: {e!r}") + sys.exit(1) + else: + pbar.update(1) + except KeyboardInterrupt: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + log.info("Received Keyboard Interrupt, stopping...") + return + @staticmethod def get_drm( keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],