forked from DRMTalks/devine
Multi-thread the new HLS download system
This mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
parent
314079c75f
commit
9e6f5b25f3
|
@ -4,8 +4,14 @@ import asyncio
|
|||
import logging
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from queue import Queue
|
||||
from threading import Event
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import m3u8
|
||||
|
@ -205,21 +211,17 @@ class HLS:
|
|||
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
|
||||
sys.exit(1)
|
||||
|
||||
init_data = None
|
||||
last_segment_key: tuple[Optional[Union[ClearKey, Widevine]], Optional[m3u8.Key]] = (None, None)
|
||||
state_event = Event()
|
||||
|
||||
for i, segment in enumerate(tqdm(master.segments, unit="segments")):
|
||||
segment_filename = str(i).zfill(len(str(len(master.segments))))
|
||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
||||
def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue):
|
||||
time.sleep(0.1)
|
||||
if state_event.is_set():
|
||||
return
|
||||
|
||||
if segment.key and last_segment_key[1] != segment.key:
|
||||
# try:
|
||||
# drm = HLS.get_drm([segment.key])
|
||||
# except NotImplementedError:
|
||||
# drm = None # never mind, try with master.keys
|
||||
# if not drm and master.keys:
|
||||
# # TODO: segment might have multiple keys but m3u8 only grabs the last!
|
||||
# drm = HLS.get_drm(master.keys)
|
||||
segment_save_path = (save_dir / filename).with_suffix(".mp4")
|
||||
|
||||
newest_segment_key = segment_key.get()
|
||||
if segment.key and newest_segment_key[1] != segment.key:
|
||||
try:
|
||||
drm = HLS.get_drm(
|
||||
# TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY
|
||||
|
@ -242,12 +244,14 @@ class HLS:
|
|||
if not license_widevine:
|
||||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
||||
license_widevine(drm)
|
||||
last_segment_key = (drm, segment.key)
|
||||
newest_segment_key = (drm, segment.key)
|
||||
segment_key.put(newest_segment_key)
|
||||
|
||||
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
|
||||
continue
|
||||
return
|
||||
|
||||
if segment.init_section and (not init_data or segment.discontinuity):
|
||||
newest_init_data = init_data.get()
|
||||
if segment.init_section and (not newest_init_data or segment.discontinuity):
|
||||
# Only use the init data if there's no init data yet (e.g., start of file)
|
||||
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
|
||||
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
|
||||
|
@ -258,7 +262,8 @@ class HLS:
|
|||
log.debug("Got new init segment, %s", segment.init_section.uri)
|
||||
res = session.get(segment.init_section.uri)
|
||||
res.raise_for_status()
|
||||
init_data = res.content
|
||||
newest_init_data = res.content
|
||||
init_data.put(newest_init_data)
|
||||
|
||||
if not segment.uri.startswith(segment.base_uri):
|
||||
segment.uri = segment.base_uri + segment.uri
|
||||
|
@ -270,7 +275,7 @@ class HLS:
|
|||
proxy
|
||||
))
|
||||
|
||||
if isinstance(track, Audio) or init_data:
|
||||
if isinstance(track, Audio) or newest_init_data:
|
||||
with open(segment_save_path, "rb+") as f:
|
||||
segment_data = f.read()
|
||||
if isinstance(track, Audio):
|
||||
|
@ -282,17 +287,53 @@ class HLS:
|
|||
segment_data
|
||||
)
|
||||
# prepend the init data to be able to decrypt
|
||||
if init_data:
|
||||
if newest_init_data:
|
||||
f.seek(0)
|
||||
f.write(init_data)
|
||||
f.write(newest_init_data)
|
||||
f.write(segment_data)
|
||||
|
||||
if last_segment_key[0]:
|
||||
last_segment_key[0].decrypt(segment_save_path)
|
||||
if newest_segment_key[0]:
|
||||
newest_segment_key[0].decrypt(segment_save_path)
|
||||
track.drm = None
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
|
||||
init_data = Queue(maxsize=1)
|
||||
segment_key = Queue(maxsize=1)
|
||||
# otherwise will be stuck waiting on the first pool, forever
|
||||
init_data.put(None)
|
||||
segment_key.put((None, None))
|
||||
|
||||
with tqdm(total=len(master.segments), unit="segments") as pbar:
|
||||
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||
try:
|
||||
for download in futures.as_completed((
|
||||
pool.submit(
|
||||
download_segment,
|
||||
filename=str(i).zfill(len(str(len(master.segments)))),
|
||||
segment=segment,
|
||||
init_data=init_data,
|
||||
segment_key=segment_key
|
||||
)
|
||||
for i, segment in enumerate(master.segments)
|
||||
)):
|
||||
if download.cancelled():
|
||||
continue
|
||||
e = download.exception()
|
||||
if e:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
traceback.print_exception(e)
|
||||
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
pbar.update(1)
|
||||
except KeyboardInterrupt:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
log.info("Received Keyboard Interrupt, stopping...")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_drm(
|
||||
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],
|
||||
|
|
Loading…
Reference in New Issue