Multi-thread the new HLS download system

This mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
rlaphoenix 2023-02-21 16:16:12 +00:00
parent 314079c75f
commit 9e6f5b25f3
1 changed files with 63 additions and 22 deletions

View File

@ -4,8 +4,14 @@ import asyncio
import logging import logging
import re import re
import sys import sys
import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from queue import Queue
from threading import Event
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
import m3u8 import m3u8
@ -205,21 +211,17 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.") log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1) sys.exit(1)
init_data = None state_event = Event()
last_segment_key: tuple[Optional[Union[ClearKey, Widevine]], Optional[m3u8.Key]] = (None, None)
for i, segment in enumerate(tqdm(master.segments, unit="segments")): def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue):
segment_filename = str(i).zfill(len(str(len(master.segments)))) time.sleep(0.1)
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4") if state_event.is_set():
return
if segment.key and last_segment_key[1] != segment.key: segment_save_path = (save_dir / filename).with_suffix(".mp4")
# try:
# drm = HLS.get_drm([segment.key]) newest_segment_key = segment_key.get()
# except NotImplementedError: if segment.key and newest_segment_key[1] != segment.key:
# drm = None # never mind, try with master.keys
# if not drm and master.keys:
# # TODO: segment might have multiple keys but m3u8 only grabs the last!
# drm = HLS.get_drm(master.keys)
try: try:
drm = HLS.get_drm( drm = HLS.get_drm(
# TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY # TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY
@ -242,12 +244,14 @@ class HLS:
if not license_widevine: if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM") raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm) license_widevine(drm)
last_segment_key = (drm, segment.key) newest_segment_key = (drm, segment.key)
segment_key.put(newest_segment_key)
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment): if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
continue return
if segment.init_section and (not init_data or segment.discontinuity): newest_init_data = init_data.get()
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file) # Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP. # or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would # Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
@ -258,7 +262,8 @@ class HLS:
log.debug("Got new init segment, %s", segment.init_section.uri) log.debug("Got new init segment, %s", segment.init_section.uri)
res = session.get(segment.init_section.uri) res = session.get(segment.init_section.uri)
res.raise_for_status() res.raise_for_status()
init_data = res.content newest_init_data = res.content
init_data.put(newest_init_data)
if not segment.uri.startswith(segment.base_uri): if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri segment.uri = segment.base_uri + segment.uri
@ -270,7 +275,7 @@ class HLS:
proxy proxy
)) ))
if isinstance(track, Audio) or init_data: if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f: with open(segment_save_path, "rb+") as f:
segment_data = f.read() segment_data = f.read()
if isinstance(track, Audio): if isinstance(track, Audio):
@ -282,17 +287,53 @@ class HLS:
segment_data segment_data
) )
# prepend the init data to be able to decrypt # prepend the init data to be able to decrypt
if init_data: if newest_init_data:
f.seek(0) f.seek(0)
f.write(init_data) f.write(newest_init_data)
f.write(segment_data) f.write(segment_data)
if last_segment_key[0]: if newest_segment_key[0]:
last_segment_key[0].decrypt(segment_save_path) newest_segment_key[0].decrypt(segment_save_path)
track.drm = None track.drm = None
if callable(track.OnDecrypted): if callable(track.OnDecrypted):
track.OnDecrypted(track) track.OnDecrypted(track)
init_data = Queue(maxsize=1)
segment_key = Queue(maxsize=1)
# otherwise will be stuck waiting on the first pool, forever
init_data.put(None)
segment_key.put((None, None))
with tqdm(total=len(master.segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
log.info("Received Keyboard Interrupt, stopping...")
return
@staticmethod @staticmethod
def get_drm( def get_drm(
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]], keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],