Multi-thread the new DASH download system, improve redundency

Just like the commit for HLS multi-threading, this mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
rlaphoenix 2023-02-22 01:06:03 +00:00
parent 9e6f5b25f3
commit 4e875f5ffc
1 changed files with 156 additions and 158 deletions

View File

@ -6,9 +6,14 @@ import logging
import math
import re
import sys
import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from hashlib import md5
from pathlib import Path
from threading import Event
from typing import Any, Callable, Optional, Union
from urllib.parse import urljoin, urlparse
from uuid import UUID
@ -303,7 +308,6 @@ class DASH:
else:
drm = None
segment_urls: list[str] = []
manifest = load_xml(session.get(manifest_url).text)
manifest_url_query = urlparse(manifest_url).query
@ -312,20 +316,27 @@ class DASH:
period_base_url = urljoin(manifest_url, period_base_url)
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
init_data: Optional[bytes] = None
base_url = representation.findtext("BaseURL") or period_base_url
segment_template = representation.find("SegmentTemplate")
if segment_template is None:
segment_template = adaptation_set.find("SegmentTemplate")
segment_base = representation.find("SegmentBase")
if segment_base is None:
segment_base = adaptation_set.find("SegmentBase")
segment_list = representation.find("SegmentList")
if segment_list is None:
segment_list = adaptation_set.find("SegmentList")
if segment_template is None and segment_list is None and base_url:
# If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL
# Regardless which of the two is used, we can just directly grab the BaseURL
# Players would normally calculate segments via Byte-Ranges, but we don't care
track.url = urljoin(period_base_url, base_url)
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
segments: list[tuple[str, Optional[str]]] = []
if segment_template is not None:
segment_template = copy(segment_template)
start_number = int(segment_template.get("startNumber") or 1)
@ -343,6 +354,16 @@ class DASH:
value += f"?{manifest_url_query}"
segment_template.set(item, value)
init_url = segment_template.get("initialization")
if init_url:
res = session.get(DASH.replace_fields(
init_url,
Bandwidth=representation.get("bandwidth"),
RepresentationID=representation.get("id")
))
res.raise_for_status()
init_data = res.content
if segment_timeline is not None:
seg_time_list = []
current_time = 0
@ -353,99 +374,38 @@ class DASH:
seg_time_list.append(current_time)
current_time += int(s.get("d"))
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
segment_urls += [
for t, n in zip(seg_time_list, seg_num_list):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=n,
RepresentationID=representation.get("id"),
Time=t
)
for t, n in zip(seg_time_list, seg_num_list)
]
), None
))
else:
if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
segment_urls += [
for s in range(start_number, start_number + total_segments):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s
)
for s in range(start_number, start_number + total_segments)
]
init_data = None
init_url = segment_template.get("initialization")
if init_url:
res = session.get(DASH.replace_fields(
init_url,
Bandwidth=representation.get("bandwidth"),
RepresentationID=representation.get("id")
), None
))
res.raise_for_status()
init_data = res.content
if not drm:
try:
drm = Widevine.from_init_data(init_data)
except Widevine.Exceptions.PSSHNotFound:
# it might not have Widevine DRM, or might not have found the PSSH
log.warning("No Widevine PSSH was found for this track, is it DRM free?")
else:
# license and grab content keys
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm)
for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")):
segment_filename = str(i).zfill(len(str(len(segment_urls))))
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
asyncio.run(aria2c(
segment_url,
segment_save_path,
session.headers,
proxy
))
if isinstance(track, Audio) or init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Is this in mpeg data, or init data?
segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# prepend the init data to be able to decrypt
if init_data:
f.seek(0)
f.write(init_data)
f.write(segment_data)
if drm:
# TODO: What if the manifest does not mention DRM, but has DRM
drm.decrypt(segment_save_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
elif segment_list is not None:
base_media_url = urljoin(period_base_url, base_url)
if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
# at least one segment has no URL specified, it uses the base url and ranges
track.url = base_media_url
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
init_data = None
initialization = segment_list.find("Initialization")
if initialization:
@ -456,7 +416,23 @@ class DASH:
res = session.get(source_url)
res.raise_for_status()
init_data = res.content
if not drm:
segment_urls = segment_list.findall("SegmentURL")
for segment_url in segment_urls:
media_url = segment_url.get("media")
if media_url is None:
media_url = base_media_url
segments.append((
media_url,
segment_url.get("mediaRange")
))
else:
log.error("Could not find a way to get segments from this MPD manifest.")
log.debug(manifest_url)
sys.exit(1)
if not drm and isinstance(track, (Video, Audio)):
try:
drm = Widevine.from_init_data(init_data)
except Widevine.Exceptions.PSSHNotFound:
@ -468,20 +444,23 @@ class DASH:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm)
for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")):
segment_filename = str(i).zfill(len(str(len(segment_urls))))
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
state_event = Event()
media_url = segment_url.get("media")
if media_url is None:
media_url = base_media_url
def download_segment(filename: str, segment: tuple[str, Optional[str]]):
time.sleep(0.1)
if state_event.is_set():
return
segment_save_path = (save_dir / filename).with_suffix(".mp4")
segment_uri, segment_range = segment
asyncio.run(aria2c(
media_url,
segment_uri,
segment_save_path,
session.headers,
proxy,
byte_range=segment_url.get("mediaRange")
byte_range=segment_range
))
if isinstance(track, Audio) or init_data:
@ -507,15 +486,34 @@ class DASH:
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
elif segment_base is not None or base_url:
# SegmentBase more or less boils down to defined ByteRanges
# So, we don't care, just download the full file
track.url = urljoin(period_base_url, base_url)
track.descriptor = track.Descriptor.URL
track.drm = [drm] if drm else []
else:
log.error("Could not find a way to get segments from this MPD manifest.")
with tqdm(total=len(segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
)
for i, segment in enumerate(segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
log.info("Received Keyboard Interrupt, stopping...")
return
@staticmethod
def get_language(*options: Any) -> Optional[Language]: