Multi-thread the new DASH download system, improve redundancy

Just like the HLS multi-threading commit, this mimics aria2c's -j 16 concurrency, but manually via a ThreadPoolExecutor. The benefit is that we keep support for the new download system, and we now get a useful tqdm progress bar on segmented downloads, unlike aria2c, which essentially fills the terminal with jumbled download progress stubs.
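
For context, a minimal self-contained sketch of the pattern this commit moves to: per-segment downloads fanned out to a ThreadPoolExecutor, one shared tqdm bar ticked as futures complete, and an Event so queued workers bail out early on failure or Ctrl+C. The names download_all, download_segment and the segment list below are illustrative stand-ins, not devine's actual API.

# Sketch only, assuming Python 3.9+ (for cancel_futures) and tqdm installed.
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from threading import Event

from tqdm import tqdm


def download_all(segments: list[str], max_workers: int = 16) -> None:
    stop = Event()  # set on failure/interrupt so queued workers return early

    def download_segment(url: str) -> None:
        if stop.is_set():
            return
        time.sleep(0.1)  # placeholder: fetch and save the segment here

    with tqdm(total=len(segments), unit="segments") as pbar:
        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            try:
                for job in futures.as_completed(
                    pool.submit(download_segment, url) for url in segments
                ):
                    if job.cancelled():
                        continue
                    exc = job.exception()  # worker errors surface here
                    if exc:
                        stop.set()
                        pool.shutdown(wait=False, cancel_futures=True)
                        raise exc
                    pbar.update(1)  # one tick per finished segment
            except KeyboardInterrupt:
                stop.set()
                pool.shutdown(wait=False, cancel_futures=True)


if __name__ == "__main__":
    download_all([f"segment-{i}.m4s" for i in range(64)])
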
rlaphoenix 2023-02-22 01:06:03 +00:00
parent 9e6f5b25f3
commit 4e875f5ffc
1 changed file with 156 additions and 158 deletions


@@ -6,9 +6,14 @@ import logging
 import math
 import re
 import sys
+import time
+import traceback
+from concurrent import futures
+from concurrent.futures import ThreadPoolExecutor
 from copy import copy
 from hashlib import md5
 from pathlib import Path
+from threading import Event
 from typing import Any, Callable, Optional, Union
 from urllib.parse import urljoin, urlparse
 from uuid import UUID
@@ -303,7 +308,6 @@ class DASH:
         else:
             drm = None
-        segment_urls: list[str] = []
 
         manifest = load_xml(session.get(manifest_url).text)
         manifest_url_query = urlparse(manifest_url).query
@@ -312,20 +316,27 @@ class DASH:
             period_base_url = urljoin(manifest_url, period_base_url)
         period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
 
+        init_data: Optional[bytes] = None
+
         base_url = representation.findtext("BaseURL") or period_base_url
 
         segment_template = representation.find("SegmentTemplate")
         if segment_template is None:
             segment_template = adaptation_set.find("SegmentTemplate")
 
-        segment_base = representation.find("SegmentBase")
-        if segment_base is None:
-            segment_base = adaptation_set.find("SegmentBase")
-
         segment_list = representation.find("SegmentList")
         if segment_list is None:
             segment_list = adaptation_set.find("SegmentList")
 
+        if segment_template is None and segment_list is None and base_url:
+            # If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL
+            # Regardless which of the two is used, we can just directly grab the BaseURL
+            # Players would normally calculate segments via Byte-Ranges, but we don't care
+            track.url = urljoin(period_base_url, base_url)
+            track.descriptor = track.Descriptor.URL
+            track.drm = [drm] if drm else []
+        else:
+            segments: list[tuple[str, Optional[str]]] = []
+
             if segment_template is not None:
                 segment_template = copy(segment_template)
                 start_number = int(segment_template.get("startNumber") or 1)
@@ -343,6 +354,16 @@ class DASH:
                         value += f"?{manifest_url_query}"
                     segment_template.set(item, value)
 
+                init_url = segment_template.get("initialization")
+                if init_url:
+                    res = session.get(DASH.replace_fields(
+                        init_url,
+                        Bandwidth=representation.get("bandwidth"),
+                        RepresentationID=representation.get("id")
+                    ))
+                    res.raise_for_status()
+                    init_data = res.content
+
                 if segment_timeline is not None:
                     seg_time_list = []
                     current_time = 0
@@ -353,99 +374,38 @@ class DASH:
                         seg_time_list.append(current_time)
                         current_time += int(s.get("d"))
                     seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
 
-                    segment_urls += [
+                    for t, n in zip(seg_time_list, seg_num_list):
+                        segments.append((
                             DASH.replace_fields(
                                 segment_template.get("media"),
                                 Bandwidth=representation.get("bandwidth"),
                                 Number=n,
                                 RepresentationID=representation.get("id"),
                                 Time=t
-                            )
-                        for t, n in zip(seg_time_list, seg_num_list)
-                    ]
+                            ), None
+                        ))
                 else:
                     if not period_duration:
                         raise ValueError("Duration of the Period was unable to be determined.")
                     period_duration = DASH.pt_to_sec(period_duration)
                     segment_duration = float(segment_template.get("duration"))
                     segment_timescale = float(segment_template.get("timescale") or 1)
                     total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
 
-                    segment_urls += [
+                    for s in range(start_number, start_number + total_segments):
+                        segments.append((
                             DASH.replace_fields(
                                 segment_template.get("media"),
                                 Bandwidth=representation.get("bandwidth"),
                                 Number=s,
                                 RepresentationID=representation.get("id"),
                                 Time=s
-                            )
-                        for s in range(start_number, start_number + total_segments)
-                    ]
+                            ), None
+                        ))
 
-                init_data = None
-                init_url = segment_template.get("initialization")
-                if init_url:
-                    res = session.get(DASH.replace_fields(
-                        init_url,
-                        Bandwidth=representation.get("bandwidth"),
-                        RepresentationID=representation.get("id")
-                    ))
-                    res.raise_for_status()
-                    init_data = res.content
-
-                if not drm:
-                    try:
-                        drm = Widevine.from_init_data(init_data)
-                    except Widevine.Exceptions.PSSHNotFound:
-                        # it might not have Widevine DRM, or might not have found the PSSH
-                        log.warning("No Widevine PSSH was found for this track, is it DRM free?")
-                    else:
-                        # license and grab content keys
-                        if not license_widevine:
-                            raise ValueError("license_widevine func must be supplied to use Widevine DRM")
-                        license_widevine(drm)
-
-                for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")):
-                    segment_filename = str(i).zfill(len(str(len(segment_urls))))
-                    segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
-
-                    asyncio.run(aria2c(
-                        segment_url,
-                        segment_save_path,
-                        session.headers,
-                        proxy
-                    ))
-
-                    if isinstance(track, Audio) or init_data:
-                        with open(segment_save_path, "rb+") as f:
-                            segment_data = f.read()
-                            if isinstance(track, Audio):
-                                # fix audio decryption on ATVP by fixing the sample description index
-                                # TODO: Is this in mpeg data, or init data?
-                                segment_data = re.sub(
-                                    b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
-                                    b"\\g<1>\x01",
-                                    segment_data
-                                )
-                            # prepend the init data to be able to decrypt
-                            if init_data:
-                                f.seek(0)
-                                f.write(init_data)
-                            f.write(segment_data)
-
-                    if drm:
-                        # TODO: What if the manifest does not mention DRM, but has DRM
-                        drm.decrypt(segment_save_path)
-                        track.drm = None
-                        if callable(track.OnDecrypted):
-                            track.OnDecrypted(track)
             elif segment_list is not None:
                 base_media_url = urljoin(period_base_url, base_url)
-                if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
-                    # at least one segment has no URL specified, it uses the base url and ranges
-                    track.url = base_media_url
-                    track.descriptor = track.Descriptor.URL
-                    track.drm = [drm] if drm else []
-                else:
                 init_data = None
                 initialization = segment_list.find("Initialization")
                 if initialization:
@@ -456,7 +416,23 @@ class DASH:
                     res = session.get(source_url)
                     res.raise_for_status()
                     init_data = res.content
 
-                if not drm:
+                segment_urls = segment_list.findall("SegmentURL")
+                for segment_url in segment_urls:
+                    media_url = segment_url.get("media")
+                    if media_url is None:
+                        media_url = base_media_url
+                    segments.append((
+                        media_url,
+                        segment_url.get("mediaRange")
+                    ))
+            else:
+                log.error("Could not find a way to get segments from this MPD manifest.")
+                log.debug(manifest_url)
+                sys.exit(1)
+
+            if not drm and isinstance(track, (Video, Audio)):
                 try:
                     drm = Widevine.from_init_data(init_data)
                 except Widevine.Exceptions.PSSHNotFound:
@@ -468,20 +444,23 @@ class DASH:
                        raise ValueError("license_widevine func must be supplied to use Widevine DRM")
                     license_widevine(drm)
 
-            for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")):
-                segment_filename = str(i).zfill(len(str(len(segment_urls))))
-                segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
-
-                media_url = segment_url.get("media")
-                if media_url is None:
-                    media_url = base_media_url
+            state_event = Event()
+
+            def download_segment(filename: str, segment: tuple[str, Optional[str]]):
+                time.sleep(0.1)
+                if state_event.is_set():
+                    return
+
+                segment_save_path = (save_dir / filename).with_suffix(".mp4")
+                segment_uri, segment_range = segment
 
                 asyncio.run(aria2c(
-                    media_url,
+                    segment_uri,
                     segment_save_path,
                     session.headers,
                     proxy,
-                    byte_range=segment_url.get("mediaRange")
+                    byte_range=segment_range
                 ))
 
                 if isinstance(track, Audio) or init_data:
@@ -507,15 +486,34 @@ class DASH:
                     track.drm = None
                     if callable(track.OnDecrypted):
                         track.OnDecrypted(track)
 
-            elif segment_base is not None or base_url:
-                # SegmentBase more or less boils down to defined ByteRanges
-                # So, we don't care, just download the full file
-                track.url = urljoin(period_base_url, base_url)
-                track.descriptor = track.Descriptor.URL
-                track.drm = [drm] if drm else []
-            else:
-                log.error("Could not find a way to get segments from this MPD manifest.")
+            with tqdm(total=len(segments), unit="segments") as pbar:
+                with ThreadPoolExecutor(max_workers=16) as pool:
+                    try:
+                        for download in futures.as_completed((
+                            pool.submit(
+                                download_segment,
+                                filename=str(i).zfill(len(str(len(segments)))),
+                                segment=segment
+                            )
+                            for i, segment in enumerate(segments)
+                        )):
+                            if download.cancelled():
+                                continue
+                            e = download.exception()
+                            if e:
+                                state_event.set()
+                                pool.shutdown(wait=False, cancel_futures=True)
+                                traceback.print_exception(e)
+                                log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
                                 sys.exit(1)
+                            else:
+                                pbar.update(1)
+                    except KeyboardInterrupt:
+                        state_event.set()
+                        pool.shutdown(wait=False, cancel_futures=True)
+                        log.info("Received Keyboard Interrupt, stopping...")
+                        return
 
     @staticmethod
     def get_language(*options: Any) -> Optional[Language]: