mirror of https://github.com/devine-dl/devine.git
Multi-thread the new DASH download system, improve redundency
Just like the commit for HLS multi-threading, this mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
parent
9e6f5b25f3
commit
4e875f5ffc
|
@ -6,9 +6,14 @@ import logging
|
|||
import math
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from concurrent import futures
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from threading import Event
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from urllib.parse import urljoin, urlparse
|
||||
from uuid import UUID
|
||||
|
@ -303,7 +308,6 @@ class DASH:
|
|||
else:
|
||||
drm = None
|
||||
|
||||
segment_urls: list[str] = []
|
||||
manifest = load_xml(session.get(manifest_url).text)
|
||||
manifest_url_query = urlparse(manifest_url).query
|
||||
|
||||
|
@ -312,20 +316,27 @@ class DASH:
|
|||
period_base_url = urljoin(manifest_url, period_base_url)
|
||||
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
|
||||
|
||||
init_data: Optional[bytes] = None
|
||||
base_url = representation.findtext("BaseURL") or period_base_url
|
||||
|
||||
segment_template = representation.find("SegmentTemplate")
|
||||
if segment_template is None:
|
||||
segment_template = adaptation_set.find("SegmentTemplate")
|
||||
|
||||
segment_base = representation.find("SegmentBase")
|
||||
if segment_base is None:
|
||||
segment_base = adaptation_set.find("SegmentBase")
|
||||
|
||||
segment_list = representation.find("SegmentList")
|
||||
if segment_list is None:
|
||||
segment_list = adaptation_set.find("SegmentList")
|
||||
|
||||
if segment_template is None and segment_list is None and base_url:
|
||||
# If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL
|
||||
# Regardless which of the two is used, we can just directly grab the BaseURL
|
||||
# Players would normally calculate segments via Byte-Ranges, but we don't care
|
||||
track.url = urljoin(period_base_url, base_url)
|
||||
track.descriptor = track.Descriptor.URL
|
||||
track.drm = [drm] if drm else []
|
||||
else:
|
||||
segments: list[tuple[str, Optional[str]]] = []
|
||||
|
||||
if segment_template is not None:
|
||||
segment_template = copy(segment_template)
|
||||
start_number = int(segment_template.get("startNumber") or 1)
|
||||
|
@ -343,6 +354,16 @@ class DASH:
|
|||
value += f"?{manifest_url_query}"
|
||||
segment_template.set(item, value)
|
||||
|
||||
init_url = segment_template.get("initialization")
|
||||
if init_url:
|
||||
res = session.get(DASH.replace_fields(
|
||||
init_url,
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
RepresentationID=representation.get("id")
|
||||
))
|
||||
res.raise_for_status()
|
||||
init_data = res.content
|
||||
|
||||
if segment_timeline is not None:
|
||||
seg_time_list = []
|
||||
current_time = 0
|
||||
|
@ -353,99 +374,38 @@ class DASH:
|
|||
seg_time_list.append(current_time)
|
||||
current_time += int(s.get("d"))
|
||||
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
|
||||
segment_urls += [
|
||||
|
||||
for t, n in zip(seg_time_list, seg_num_list):
|
||||
segments.append((
|
||||
DASH.replace_fields(
|
||||
segment_template.get("media"),
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
Number=n,
|
||||
RepresentationID=representation.get("id"),
|
||||
Time=t
|
||||
)
|
||||
for t, n in zip(seg_time_list, seg_num_list)
|
||||
]
|
||||
), None
|
||||
))
|
||||
else:
|
||||
if not period_duration:
|
||||
raise ValueError("Duration of the Period was unable to be determined.")
|
||||
period_duration = DASH.pt_to_sec(period_duration)
|
||||
segment_duration = float(segment_template.get("duration"))
|
||||
segment_timescale = float(segment_template.get("timescale") or 1)
|
||||
|
||||
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
|
||||
segment_urls += [
|
||||
|
||||
for s in range(start_number, start_number + total_segments):
|
||||
segments.append((
|
||||
DASH.replace_fields(
|
||||
segment_template.get("media"),
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
Number=s,
|
||||
RepresentationID=representation.get("id"),
|
||||
Time=s
|
||||
)
|
||||
for s in range(start_number, start_number + total_segments)
|
||||
]
|
||||
|
||||
init_data = None
|
||||
init_url = segment_template.get("initialization")
|
||||
if init_url:
|
||||
res = session.get(DASH.replace_fields(
|
||||
init_url,
|
||||
Bandwidth=representation.get("bandwidth"),
|
||||
RepresentationID=representation.get("id")
|
||||
), None
|
||||
))
|
||||
res.raise_for_status()
|
||||
init_data = res.content
|
||||
if not drm:
|
||||
try:
|
||||
drm = Widevine.from_init_data(init_data)
|
||||
except Widevine.Exceptions.PSSHNotFound:
|
||||
# it might not have Widevine DRM, or might not have found the PSSH
|
||||
log.warning("No Widevine PSSH was found for this track, is it DRM free?")
|
||||
else:
|
||||
# license and grab content keys
|
||||
if not license_widevine:
|
||||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
||||
license_widevine(drm)
|
||||
|
||||
for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")):
|
||||
segment_filename = str(i).zfill(len(str(len(segment_urls))))
|
||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
||||
|
||||
asyncio.run(aria2c(
|
||||
segment_url,
|
||||
segment_save_path,
|
||||
session.headers,
|
||||
proxy
|
||||
))
|
||||
|
||||
if isinstance(track, Audio) or init_data:
|
||||
with open(segment_save_path, "rb+") as f:
|
||||
segment_data = f.read()
|
||||
if isinstance(track, Audio):
|
||||
# fix audio decryption on ATVP by fixing the sample description index
|
||||
# TODO: Is this in mpeg data, or init data?
|
||||
segment_data = re.sub(
|
||||
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
|
||||
b"\\g<1>\x01",
|
||||
segment_data
|
||||
)
|
||||
# prepend the init data to be able to decrypt
|
||||
if init_data:
|
||||
f.seek(0)
|
||||
f.write(init_data)
|
||||
f.write(segment_data)
|
||||
|
||||
if drm:
|
||||
# TODO: What if the manifest does not mention DRM, but has DRM
|
||||
drm.decrypt(segment_save_path)
|
||||
track.drm = None
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
elif segment_list is not None:
|
||||
base_media_url = urljoin(period_base_url, base_url)
|
||||
if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
|
||||
# at least one segment has no URL specified, it uses the base url and ranges
|
||||
track.url = base_media_url
|
||||
track.descriptor = track.Descriptor.URL
|
||||
track.drm = [drm] if drm else []
|
||||
else:
|
||||
|
||||
init_data = None
|
||||
initialization = segment_list.find("Initialization")
|
||||
if initialization:
|
||||
|
@ -456,7 +416,23 @@ class DASH:
|
|||
res = session.get(source_url)
|
||||
res.raise_for_status()
|
||||
init_data = res.content
|
||||
if not drm:
|
||||
|
||||
segment_urls = segment_list.findall("SegmentURL")
|
||||
for segment_url in segment_urls:
|
||||
media_url = segment_url.get("media")
|
||||
if media_url is None:
|
||||
media_url = base_media_url
|
||||
|
||||
segments.append((
|
||||
media_url,
|
||||
segment_url.get("mediaRange")
|
||||
))
|
||||
else:
|
||||
log.error("Could not find a way to get segments from this MPD manifest.")
|
||||
log.debug(manifest_url)
|
||||
sys.exit(1)
|
||||
|
||||
if not drm and isinstance(track, (Video, Audio)):
|
||||
try:
|
||||
drm = Widevine.from_init_data(init_data)
|
||||
except Widevine.Exceptions.PSSHNotFound:
|
||||
|
@ -468,20 +444,23 @@ class DASH:
|
|||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
||||
license_widevine(drm)
|
||||
|
||||
for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")):
|
||||
segment_filename = str(i).zfill(len(str(len(segment_urls))))
|
||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
||||
state_event = Event()
|
||||
|
||||
media_url = segment_url.get("media")
|
||||
if media_url is None:
|
||||
media_url = base_media_url
|
||||
def download_segment(filename: str, segment: tuple[str, Optional[str]]):
|
||||
time.sleep(0.1)
|
||||
if state_event.is_set():
|
||||
return
|
||||
|
||||
segment_save_path = (save_dir / filename).with_suffix(".mp4")
|
||||
|
||||
segment_uri, segment_range = segment
|
||||
|
||||
asyncio.run(aria2c(
|
||||
media_url,
|
||||
segment_uri,
|
||||
segment_save_path,
|
||||
session.headers,
|
||||
proxy,
|
||||
byte_range=segment_url.get("mediaRange")
|
||||
byte_range=segment_range
|
||||
))
|
||||
|
||||
if isinstance(track, Audio) or init_data:
|
||||
|
@ -507,15 +486,34 @@ class DASH:
|
|||
track.drm = None
|
||||
if callable(track.OnDecrypted):
|
||||
track.OnDecrypted(track)
|
||||
elif segment_base is not None or base_url:
|
||||
# SegmentBase more or less boils down to defined ByteRanges
|
||||
# So, we don't care, just download the full file
|
||||
track.url = urljoin(period_base_url, base_url)
|
||||
track.descriptor = track.Descriptor.URL
|
||||
track.drm = [drm] if drm else []
|
||||
else:
|
||||
log.error("Could not find a way to get segments from this MPD manifest.")
|
||||
|
||||
with tqdm(total=len(segments), unit="segments") as pbar:
|
||||
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||
try:
|
||||
for download in futures.as_completed((
|
||||
pool.submit(
|
||||
download_segment,
|
||||
filename=str(i).zfill(len(str(len(segments)))),
|
||||
segment=segment
|
||||
)
|
||||
for i, segment in enumerate(segments)
|
||||
)):
|
||||
if download.cancelled():
|
||||
continue
|
||||
e = download.exception()
|
||||
if e:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
traceback.print_exception(e)
|
||||
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
pbar.update(1)
|
||||
except KeyboardInterrupt:
|
||||
state_event.set()
|
||||
pool.shutdown(wait=False, cancel_futures=True)
|
||||
log.info("Received Keyboard Interrupt, stopping...")
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def get_language(*options: Any) -> Optional[Language]:
|
||||
|
|
Loading…
Reference in New Issue