Multi-thread the new HLS download system

This mimics the -j=16 system of aria2c, but does so manually via a ThreadPoolExecutor. The benefits of this are that we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c, which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
rlaphoenix 2023-02-21 16:16:12 +00:00
parent 314079c75f
commit 9e6f5b25f3
1 changed files with 63 additions and 22 deletions

View File

@ -4,8 +4,14 @@ import asyncio
import logging
import re
import sys
import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from hashlib import md5
from pathlib import Path
from queue import Queue
from threading import Event
from typing import Any, Callable, Optional, Union
import m3u8
@ -205,21 +211,17 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1)
init_data = None
last_segment_key: tuple[Optional[Union[ClearKey, Widevine]], Optional[m3u8.Key]] = (None, None)
state_event = Event()
for i, segment in enumerate(tqdm(master.segments, unit="segments")):
segment_filename = str(i).zfill(len(str(len(master.segments))))
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue):
time.sleep(0.1)
if state_event.is_set():
return
if segment.key and last_segment_key[1] != segment.key:
# try:
# drm = HLS.get_drm([segment.key])
# except NotImplementedError:
# drm = None # never mind, try with master.keys
# if not drm and master.keys:
# # TODO: segment might have multiple keys but m3u8 only grabs the last!
# drm = HLS.get_drm(master.keys)
segment_save_path = (save_dir / filename).with_suffix(".mp4")
newest_segment_key = segment_key.get()
if segment.key and newest_segment_key[1] != segment.key:
try:
drm = HLS.get_drm(
# TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY
@ -242,12 +244,14 @@ class HLS:
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm)
last_segment_key = (drm, segment.key)
newest_segment_key = (drm, segment.key)
segment_key.put(newest_segment_key)
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
continue
return
if segment.init_section and (not init_data or segment.discontinuity):
newest_init_data = init_data.get()
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
@ -258,7 +262,8 @@ class HLS:
log.debug("Got new init segment, %s", segment.init_section.uri)
res = session.get(segment.init_section.uri)
res.raise_for_status()
init_data = res.content
newest_init_data = res.content
init_data.put(newest_init_data)
if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri
@ -270,7 +275,7 @@ class HLS:
proxy
))
if isinstance(track, Audio) or init_data:
if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
@ -282,17 +287,53 @@ class HLS:
segment_data
)
# prepend the init data to be able to decrypt
if init_data:
if newest_init_data:
f.seek(0)
f.write(init_data)
f.write(newest_init_data)
f.write(segment_data)
if last_segment_key[0]:
last_segment_key[0].decrypt(segment_save_path)
if newest_segment_key[0]:
newest_segment_key[0].decrypt(segment_save_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
init_data = Queue(maxsize=1)
segment_key = Queue(maxsize=1)
# seed both queues with an initial value, otherwise the first worker's get() would block forever
init_data.put(None)
segment_key.put((None, None))
with tqdm(total=len(master.segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
log.info("Received Keyboard Interrupt, stopping...")
return
@staticmethod
def get_drm(
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],