Multi-thread the new DASH download system, improve redundency

Just like the commit for HLS multi-threading, this mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
rlaphoenix 2023-02-22 01:06:03 +00:00
parent 9e6f5b25f3
commit 4e875f5ffc
1 changed files with 156 additions and 158 deletions

View File

@ -6,9 +6,14 @@ import logging
import math import math
import re import re
import sys import sys
import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from copy import copy from copy import copy
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from threading import Event
from typing import Any, Callable, Optional, Union from typing import Any, Callable, Optional, Union
from urllib.parse import urljoin, urlparse from urllib.parse import urljoin, urlparse
from uuid import UUID from uuid import UUID
@ -303,7 +308,6 @@ class DASH:
else: else:
drm = None drm = None
segment_urls: list[str] = []
manifest = load_xml(session.get(manifest_url).text) manifest = load_xml(session.get(manifest_url).text)
manifest_url_query = urlparse(manifest_url).query manifest_url_query = urlparse(manifest_url).query
@ -312,107 +316,151 @@ class DASH:
period_base_url = urljoin(manifest_url, period_base_url) period_base_url = urljoin(manifest_url, period_base_url)
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration") period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
init_data: Optional[bytes] = None
base_url = representation.findtext("BaseURL") or period_base_url base_url = representation.findtext("BaseURL") or period_base_url
segment_template = representation.find("SegmentTemplate") segment_template = representation.find("SegmentTemplate")
if segment_template is None: if segment_template is None:
segment_template = adaptation_set.find("SegmentTemplate") segment_template = adaptation_set.find("SegmentTemplate")
segment_base = representation.find("SegmentBase")
if segment_base is None:
segment_base = adaptation_set.find("SegmentBase")
segment_list = representation.find("SegmentList") segment_list = representation.find("SegmentList")
if segment_list is None: if segment_list is None:
segment_list = adaptation_set.find("SegmentList") segment_list = adaptation_set.find("SegmentList")
if segment_template is not None: if segment_template is None and segment_list is None and base_url:
segment_template = copy(segment_template) # If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL
start_number = int(segment_template.get("startNumber") or 1) # Regardless which of the two is used, we can just directly grab the BaseURL
segment_timeline = segment_template.find("SegmentTimeline") # Players would normally calculate segments via Byte-Ranges, but we don't care
track.url = urljoin(period_base_url, base_url)
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
segments: list[tuple[str, Optional[str]]] = []
for item in ("initialization", "media"): if segment_template is not None:
value = segment_template.get(item) segment_template = copy(segment_template)
if not value: start_number = int(segment_template.get("startNumber") or 1)
continue segment_timeline = segment_template.find("SegmentTimeline")
if not re.match("^https?://", value, re.IGNORECASE):
if not base_url:
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
value = urljoin(base_url, value)
if not urlparse(value).query and manifest_url_query:
value += f"?{manifest_url_query}"
segment_template.set(item, value)
if segment_timeline is not None: for item in ("initialization", "media"):
seg_time_list = [] value = segment_template.get(item)
current_time = 0 if not value:
for s in segment_timeline.findall("S"): continue
if s.get("t"): if not re.match("^https?://", value, re.IGNORECASE):
current_time = int(s.get("t")) if not base_url:
for _ in range(1 + (int(s.get("r") or 0))): raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
seg_time_list.append(current_time) value = urljoin(base_url, value)
current_time += int(s.get("d")) if not urlparse(value).query and manifest_url_query:
seg_num_list = list(range(start_number, len(seg_time_list) + start_number)) value += f"?{manifest_url_query}"
segment_urls += [ segment_template.set(item, value)
DASH.replace_fields(
segment_template.get("media"), init_url = segment_template.get("initialization")
if init_url:
res = session.get(DASH.replace_fields(
init_url,
Bandwidth=representation.get("bandwidth"), Bandwidth=representation.get("bandwidth"),
Number=n, RepresentationID=representation.get("id")
RepresentationID=representation.get("id"), ))
Time=t res.raise_for_status()
) init_data = res.content
for t, n in zip(seg_time_list, seg_num_list)
] if segment_timeline is not None:
seg_time_list = []
current_time = 0
for s in segment_timeline.findall("S"):
if s.get("t"):
current_time = int(s.get("t"))
for _ in range(1 + (int(s.get("r") or 0))):
seg_time_list.append(current_time)
current_time += int(s.get("d"))
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
for t, n in zip(seg_time_list, seg_num_list):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=n,
RepresentationID=representation.get("id"),
Time=t
), None
))
else:
if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
for s in range(start_number, start_number + total_segments):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s
), None
))
elif segment_list is not None:
base_media_url = urljoin(period_base_url, base_url)
init_data = None
initialization = segment_list.find("Initialization")
if initialization:
source_url = initialization.get("sourceURL")
if source_url is None:
source_url = base_media_url
res = session.get(source_url)
res.raise_for_status()
init_data = res.content
segment_urls = segment_list.findall("SegmentURL")
for segment_url in segment_urls:
media_url = segment_url.get("media")
if media_url is None:
media_url = base_media_url
segments.append((
media_url,
segment_url.get("mediaRange")
))
else: else:
if not period_duration: log.error("Could not find a way to get segments from this MPD manifest.")
raise ValueError("Duration of the Period was unable to be determined.") log.debug(manifest_url)
period_duration = DASH.pt_to_sec(period_duration) sys.exit(1)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale)) if not drm and isinstance(track, (Video, Audio)):
segment_urls += [ try:
DASH.replace_fields( drm = Widevine.from_init_data(init_data)
segment_template.get("media"), except Widevine.Exceptions.PSSHNotFound:
Bandwidth=representation.get("bandwidth"), # it might not have Widevine DRM, or might not have found the PSSH
Number=s, log.warning("No Widevine PSSH was found for this track, is it DRM free?")
RepresentationID=representation.get("id"), else:
Time=s # license and grab content keys
) if not license_widevine:
for s in range(start_number, start_number + total_segments) raise ValueError("license_widevine func must be supplied to use Widevine DRM")
] license_widevine(drm)
init_data = None state_event = Event()
init_url = segment_template.get("initialization")
if init_url:
res = session.get(DASH.replace_fields(
init_url,
Bandwidth=representation.get("bandwidth"),
RepresentationID=representation.get("id")
))
res.raise_for_status()
init_data = res.content
if not drm:
try:
drm = Widevine.from_init_data(init_data)
except Widevine.Exceptions.PSSHNotFound:
# it might not have Widevine DRM, or might not have found the PSSH
log.warning("No Widevine PSSH was found for this track, is it DRM free?")
else:
# license and grab content keys
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm)
for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")): def download_segment(filename: str, segment: tuple[str, Optional[str]]):
segment_filename = str(i).zfill(len(str(len(segment_urls)))) time.sleep(0.1)
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4") if state_event.is_set():
return
segment_save_path = (save_dir / filename).with_suffix(".mp4")
segment_uri, segment_range = segment
asyncio.run(aria2c( asyncio.run(aria2c(
segment_url, segment_uri,
segment_save_path, segment_save_path,
session.headers, session.headers,
proxy proxy,
byte_range=segment_range
)) ))
if isinstance(track, Audio) or init_data: if isinstance(track, Audio) or init_data:
@ -438,84 +486,34 @@ class DASH:
track.drm = None track.drm = None
if callable(track.OnDecrypted): if callable(track.OnDecrypted):
track.OnDecrypted(track) track.OnDecrypted(track)
elif segment_list is not None:
base_media_url = urljoin(period_base_url, base_url)
if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
# at least one segment has no URL specified, it uses the base url and ranges
track.url = base_media_url
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
init_data = None
initialization = segment_list.find("Initialization")
if initialization:
source_url = initialization.get("sourceURL")
if source_url is None:
source_url = base_media_url
res = session.get(source_url) with tqdm(total=len(segments), unit="segments") as pbar:
res.raise_for_status() with ThreadPoolExecutor(max_workers=16) as pool:
init_data = res.content try:
if not drm: for download in futures.as_completed((
try: pool.submit(
drm = Widevine.from_init_data(init_data) download_segment,
except Widevine.Exceptions.PSSHNotFound: filename=str(i).zfill(len(str(len(segments)))),
# it might not have Widevine DRM, or might not have found the PSSH segment=segment
log.warning("No Widevine PSSH was found for this track, is it DRM free?") )
else: for i, segment in enumerate(segments)
# license and grab content keys )):
if not license_widevine: if download.cancelled():
raise ValueError("license_widevine func must be supplied to use Widevine DRM") continue
license_widevine(drm) e = download.exception()
if e:
for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")): state_event.set()
segment_filename = str(i).zfill(len(str(len(segment_urls)))) pool.shutdown(wait=False, cancel_futures=True)
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4") traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
media_url = segment_url.get("media") sys.exit(1)
if media_url is None: else:
media_url = base_media_url pbar.update(1)
except KeyboardInterrupt:
asyncio.run(aria2c( state_event.set()
media_url, pool.shutdown(wait=False, cancel_futures=True)
segment_save_path, log.info("Received Keyboard Interrupt, stopping...")
session.headers, return
proxy,
byte_range=segment_url.get("mediaRange")
))
if isinstance(track, Audio) or init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Is this in mpeg data, or init data?
segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# prepend the init data to be able to decrypt
if init_data:
f.seek(0)
f.write(init_data)
f.write(segment_data)
if drm:
# TODO: What if the manifest does not mention DRM, but has DRM
drm.decrypt(segment_save_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
elif segment_base is not None or base_url:
# SegmentBase more or less boils down to defined ByteRanges
# So, we don't care, just download the full file
track.url = urljoin(period_base_url, base_url)
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
log.error("Could not find a way to get segments from this MPD manifest.")
sys.exit(1)
@staticmethod @staticmethod
def get_language(*options: Any) -> Optional[Language]: def get_language(*options: Any) -> Optional[Language]: