forked from DRMTalks/devine
Multi-thread the new HLS download system
This mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
parent
314079c75f
commit
9e6f5b25f3
|
@ -4,8 +4,14 @@ import asyncio
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from concurrent import futures
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from queue import Queue
|
||||||
|
from threading import Event
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import m3u8
|
import m3u8
|
||||||
|
@ -205,21 +211,17 @@ class HLS:
|
||||||
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
|
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
init_data = None
|
state_event = Event()
|
||||||
last_segment_key: tuple[Optional[Union[ClearKey, Widevine]], Optional[m3u8.Key]] = (None, None)
|
|
||||||
|
|
||||||
for i, segment in enumerate(tqdm(master.segments, unit="segments")):
|
def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue):
|
||||||
segment_filename = str(i).zfill(len(str(len(master.segments))))
|
time.sleep(0.1)
|
||||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
if state_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
if segment.key and last_segment_key[1] != segment.key:
|
segment_save_path = (save_dir / filename).with_suffix(".mp4")
|
||||||
# try:
|
|
||||||
# drm = HLS.get_drm([segment.key])
|
newest_segment_key = segment_key.get()
|
||||||
# except NotImplementedError:
|
if segment.key and newest_segment_key[1] != segment.key:
|
||||||
# drm = None # never mind, try with master.keys
|
|
||||||
# if not drm and master.keys:
|
|
||||||
# # TODO: segment might have multiple keys but m3u8 only grabs the last!
|
|
||||||
# drm = HLS.get_drm(master.keys)
|
|
||||||
try:
|
try:
|
||||||
drm = HLS.get_drm(
|
drm = HLS.get_drm(
|
||||||
# TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY
|
# TODO: We append master.keys because m3u8 class only puts the last EXT-X-KEY
|
||||||
|
@ -242,12 +244,14 @@ class HLS:
|
||||||
if not license_widevine:
|
if not license_widevine:
|
||||||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
||||||
license_widevine(drm)
|
license_widevine(drm)
|
||||||
last_segment_key = (drm, segment.key)
|
newest_segment_key = (drm, segment.key)
|
||||||
|
segment_key.put(newest_segment_key)
|
||||||
|
|
||||||
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
|
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
|
||||||
continue
|
return
|
||||||
|
|
||||||
if segment.init_section and (not init_data or segment.discontinuity):
|
newest_init_data = init_data.get()
|
||||||
|
if segment.init_section and (not newest_init_data or segment.discontinuity):
|
||||||
# Only use the init data if there's no init data yet (e.g., start of file)
|
# Only use the init data if there's no init data yet (e.g., start of file)
|
||||||
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
|
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
|
||||||
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
|
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
|
||||||
|
@ -258,7 +262,8 @@ class HLS:
|
||||||
log.debug("Got new init segment, %s", segment.init_section.uri)
|
log.debug("Got new init segment, %s", segment.init_section.uri)
|
||||||
res = session.get(segment.init_section.uri)
|
res = session.get(segment.init_section.uri)
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
init_data = res.content
|
newest_init_data = res.content
|
||||||
|
init_data.put(newest_init_data)
|
||||||
|
|
||||||
if not segment.uri.startswith(segment.base_uri):
|
if not segment.uri.startswith(segment.base_uri):
|
||||||
segment.uri = segment.base_uri + segment.uri
|
segment.uri = segment.base_uri + segment.uri
|
||||||
|
@ -270,7 +275,7 @@ class HLS:
|
||||||
proxy
|
proxy
|
||||||
))
|
))
|
||||||
|
|
||||||
if isinstance(track, Audio) or init_data:
|
if isinstance(track, Audio) or newest_init_data:
|
||||||
with open(segment_save_path, "rb+") as f:
|
with open(segment_save_path, "rb+") as f:
|
||||||
segment_data = f.read()
|
segment_data = f.read()
|
||||||
if isinstance(track, Audio):
|
if isinstance(track, Audio):
|
||||||
|
@ -282,17 +287,53 @@ class HLS:
|
||||||
segment_data
|
segment_data
|
||||||
)
|
)
|
||||||
# prepend the init data to be able to decrypt
|
# prepend the init data to be able to decrypt
|
||||||
if init_data:
|
if newest_init_data:
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
f.write(init_data)
|
f.write(newest_init_data)
|
||||||
f.write(segment_data)
|
f.write(segment_data)
|
||||||
|
|
||||||
if last_segment_key[0]:
|
if newest_segment_key[0]:
|
||||||
last_segment_key[0].decrypt(segment_save_path)
|
newest_segment_key[0].decrypt(segment_save_path)
|
||||||
track.drm = None
|
track.drm = None
|
||||||
if callable(track.OnDecrypted):
|
if callable(track.OnDecrypted):
|
||||||
track.OnDecrypted(track)
|
track.OnDecrypted(track)
|
||||||
|
|
||||||
|
init_data = Queue(maxsize=1)
|
||||||
|
segment_key = Queue(maxsize=1)
|
||||||
|
# otherwise will be stuck waiting on the first pool, forever
|
||||||
|
init_data.put(None)
|
||||||
|
segment_key.put((None, None))
|
||||||
|
|
||||||
|
with tqdm(total=len(master.segments), unit="segments") as pbar:
|
||||||
|
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||||
|
try:
|
||||||
|
for download in futures.as_completed((
|
||||||
|
pool.submit(
|
||||||
|
download_segment,
|
||||||
|
filename=str(i).zfill(len(str(len(master.segments)))),
|
||||||
|
segment=segment,
|
||||||
|
init_data=init_data,
|
||||||
|
segment_key=segment_key
|
||||||
|
)
|
||||||
|
for i, segment in enumerate(master.segments)
|
||||||
|
)):
|
||||||
|
if download.cancelled():
|
||||||
|
continue
|
||||||
|
e = download.exception()
|
||||||
|
if e:
|
||||||
|
state_event.set()
|
||||||
|
pool.shutdown(wait=False, cancel_futures=True)
|
||||||
|
traceback.print_exception(e)
|
||||||
|
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
pbar.update(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
state_event.set()
|
||||||
|
pool.shutdown(wait=False, cancel_futures=True)
|
||||||
|
log.info("Received Keyboard Interrupt, stopping...")
|
||||||
|
return
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_drm(
|
def get_drm(
|
||||||
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],
|
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],
|
||||||
|
|
Loading…
Reference in New Issue