mirror of https://github.com/devine-dl/devine.git
Multi-thread the new DASH download system, improve redundency
Just like the commit for HLS multi-threading, this mimics the -j=16 system of aria2c, but manually via a ThreadPoolExecutor. Benefits of this is we still keep support for the new system, and we now get a useful progress bar via TQDM on segmented downloads, unlike aria2c which essentially fills the terminal with jumbled download progress stubs.
This commit is contained in:
parent
9e6f5b25f3
commit
4e875f5ffc
|
@ -6,9 +6,14 @@ import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
from concurrent import futures
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from hashlib import md5
|
from hashlib import md5
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from threading import Event
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
from urllib.parse import urljoin, urlparse
|
from urllib.parse import urljoin, urlparse
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
@ -303,7 +308,6 @@ class DASH:
|
||||||
else:
|
else:
|
||||||
drm = None
|
drm = None
|
||||||
|
|
||||||
segment_urls: list[str] = []
|
|
||||||
manifest = load_xml(session.get(manifest_url).text)
|
manifest = load_xml(session.get(manifest_url).text)
|
||||||
manifest_url_query = urlparse(manifest_url).query
|
manifest_url_query = urlparse(manifest_url).query
|
||||||
|
|
||||||
|
@ -312,20 +316,27 @@ class DASH:
|
||||||
period_base_url = urljoin(manifest_url, period_base_url)
|
period_base_url = urljoin(manifest_url, period_base_url)
|
||||||
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
|
period_duration = period.get("duration") or manifest.get("mediaPresentationDuration")
|
||||||
|
|
||||||
|
init_data: Optional[bytes] = None
|
||||||
base_url = representation.findtext("BaseURL") or period_base_url
|
base_url = representation.findtext("BaseURL") or period_base_url
|
||||||
|
|
||||||
segment_template = representation.find("SegmentTemplate")
|
segment_template = representation.find("SegmentTemplate")
|
||||||
if segment_template is None:
|
if segment_template is None:
|
||||||
segment_template = adaptation_set.find("SegmentTemplate")
|
segment_template = adaptation_set.find("SegmentTemplate")
|
||||||
|
|
||||||
segment_base = representation.find("SegmentBase")
|
|
||||||
if segment_base is None:
|
|
||||||
segment_base = adaptation_set.find("SegmentBase")
|
|
||||||
|
|
||||||
segment_list = representation.find("SegmentList")
|
segment_list = representation.find("SegmentList")
|
||||||
if segment_list is None:
|
if segment_list is None:
|
||||||
segment_list = adaptation_set.find("SegmentList")
|
segment_list = adaptation_set.find("SegmentList")
|
||||||
|
|
||||||
|
if segment_template is None and segment_list is None and base_url:
|
||||||
|
# If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL
|
||||||
|
# Regardless which of the two is used, we can just directly grab the BaseURL
|
||||||
|
# Players would normally calculate segments via Byte-Ranges, but we don't care
|
||||||
|
track.url = urljoin(period_base_url, base_url)
|
||||||
|
track.descriptor = track.Descriptor.URL
|
||||||
|
track.drm = [drm] if drm else []
|
||||||
|
else:
|
||||||
|
segments: list[tuple[str, Optional[str]]] = []
|
||||||
|
|
||||||
if segment_template is not None:
|
if segment_template is not None:
|
||||||
segment_template = copy(segment_template)
|
segment_template = copy(segment_template)
|
||||||
start_number = int(segment_template.get("startNumber") or 1)
|
start_number = int(segment_template.get("startNumber") or 1)
|
||||||
|
@ -343,6 +354,16 @@ class DASH:
|
||||||
value += f"?{manifest_url_query}"
|
value += f"?{manifest_url_query}"
|
||||||
segment_template.set(item, value)
|
segment_template.set(item, value)
|
||||||
|
|
||||||
|
init_url = segment_template.get("initialization")
|
||||||
|
if init_url:
|
||||||
|
res = session.get(DASH.replace_fields(
|
||||||
|
init_url,
|
||||||
|
Bandwidth=representation.get("bandwidth"),
|
||||||
|
RepresentationID=representation.get("id")
|
||||||
|
))
|
||||||
|
res.raise_for_status()
|
||||||
|
init_data = res.content
|
||||||
|
|
||||||
if segment_timeline is not None:
|
if segment_timeline is not None:
|
||||||
seg_time_list = []
|
seg_time_list = []
|
||||||
current_time = 0
|
current_time = 0
|
||||||
|
@ -353,99 +374,38 @@ class DASH:
|
||||||
seg_time_list.append(current_time)
|
seg_time_list.append(current_time)
|
||||||
current_time += int(s.get("d"))
|
current_time += int(s.get("d"))
|
||||||
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
|
seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
|
||||||
segment_urls += [
|
|
||||||
|
for t, n in zip(seg_time_list, seg_num_list):
|
||||||
|
segments.append((
|
||||||
DASH.replace_fields(
|
DASH.replace_fields(
|
||||||
segment_template.get("media"),
|
segment_template.get("media"),
|
||||||
Bandwidth=representation.get("bandwidth"),
|
Bandwidth=representation.get("bandwidth"),
|
||||||
Number=n,
|
Number=n,
|
||||||
RepresentationID=representation.get("id"),
|
RepresentationID=representation.get("id"),
|
||||||
Time=t
|
Time=t
|
||||||
)
|
), None
|
||||||
for t, n in zip(seg_time_list, seg_num_list)
|
))
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
if not period_duration:
|
if not period_duration:
|
||||||
raise ValueError("Duration of the Period was unable to be determined.")
|
raise ValueError("Duration of the Period was unable to be determined.")
|
||||||
period_duration = DASH.pt_to_sec(period_duration)
|
period_duration = DASH.pt_to_sec(period_duration)
|
||||||
segment_duration = float(segment_template.get("duration"))
|
segment_duration = float(segment_template.get("duration"))
|
||||||
segment_timescale = float(segment_template.get("timescale") or 1)
|
segment_timescale = float(segment_template.get("timescale") or 1)
|
||||||
|
|
||||||
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
|
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
|
||||||
segment_urls += [
|
|
||||||
|
for s in range(start_number, start_number + total_segments):
|
||||||
|
segments.append((
|
||||||
DASH.replace_fields(
|
DASH.replace_fields(
|
||||||
segment_template.get("media"),
|
segment_template.get("media"),
|
||||||
Bandwidth=representation.get("bandwidth"),
|
Bandwidth=representation.get("bandwidth"),
|
||||||
Number=s,
|
Number=s,
|
||||||
RepresentationID=representation.get("id"),
|
RepresentationID=representation.get("id"),
|
||||||
Time=s
|
Time=s
|
||||||
)
|
), None
|
||||||
for s in range(start_number, start_number + total_segments)
|
|
||||||
]
|
|
||||||
|
|
||||||
init_data = None
|
|
||||||
init_url = segment_template.get("initialization")
|
|
||||||
if init_url:
|
|
||||||
res = session.get(DASH.replace_fields(
|
|
||||||
init_url,
|
|
||||||
Bandwidth=representation.get("bandwidth"),
|
|
||||||
RepresentationID=representation.get("id")
|
|
||||||
))
|
))
|
||||||
res.raise_for_status()
|
|
||||||
init_data = res.content
|
|
||||||
if not drm:
|
|
||||||
try:
|
|
||||||
drm = Widevine.from_init_data(init_data)
|
|
||||||
except Widevine.Exceptions.PSSHNotFound:
|
|
||||||
# it might not have Widevine DRM, or might not have found the PSSH
|
|
||||||
log.warning("No Widevine PSSH was found for this track, is it DRM free?")
|
|
||||||
else:
|
|
||||||
# license and grab content keys
|
|
||||||
if not license_widevine:
|
|
||||||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
|
||||||
license_widevine(drm)
|
|
||||||
|
|
||||||
for i, segment_url in enumerate(tqdm(segment_urls, unit="segments")):
|
|
||||||
segment_filename = str(i).zfill(len(str(len(segment_urls))))
|
|
||||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
|
||||||
|
|
||||||
asyncio.run(aria2c(
|
|
||||||
segment_url,
|
|
||||||
segment_save_path,
|
|
||||||
session.headers,
|
|
||||||
proxy
|
|
||||||
))
|
|
||||||
|
|
||||||
if isinstance(track, Audio) or init_data:
|
|
||||||
with open(segment_save_path, "rb+") as f:
|
|
||||||
segment_data = f.read()
|
|
||||||
if isinstance(track, Audio):
|
|
||||||
# fix audio decryption on ATVP by fixing the sample description index
|
|
||||||
# TODO: Is this in mpeg data, or init data?
|
|
||||||
segment_data = re.sub(
|
|
||||||
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
|
|
||||||
b"\\g<1>\x01",
|
|
||||||
segment_data
|
|
||||||
)
|
|
||||||
# prepend the init data to be able to decrypt
|
|
||||||
if init_data:
|
|
||||||
f.seek(0)
|
|
||||||
f.write(init_data)
|
|
||||||
f.write(segment_data)
|
|
||||||
|
|
||||||
if drm:
|
|
||||||
# TODO: What if the manifest does not mention DRM, but has DRM
|
|
||||||
drm.decrypt(segment_save_path)
|
|
||||||
track.drm = None
|
|
||||||
if callable(track.OnDecrypted):
|
|
||||||
track.OnDecrypted(track)
|
|
||||||
elif segment_list is not None:
|
elif segment_list is not None:
|
||||||
base_media_url = urljoin(period_base_url, base_url)
|
base_media_url = urljoin(period_base_url, base_url)
|
||||||
if any(x.get("media") is not None for x in segment_list.findall("SegmentURL")):
|
|
||||||
# at least one segment has no URL specified, it uses the base url and ranges
|
|
||||||
track.url = base_media_url
|
|
||||||
track.descriptor = track.Descriptor.URL
|
|
||||||
track.drm = [drm] if drm else []
|
|
||||||
else:
|
|
||||||
init_data = None
|
init_data = None
|
||||||
initialization = segment_list.find("Initialization")
|
initialization = segment_list.find("Initialization")
|
||||||
if initialization:
|
if initialization:
|
||||||
|
@ -456,7 +416,23 @@ class DASH:
|
||||||
res = session.get(source_url)
|
res = session.get(source_url)
|
||||||
res.raise_for_status()
|
res.raise_for_status()
|
||||||
init_data = res.content
|
init_data = res.content
|
||||||
if not drm:
|
|
||||||
|
segment_urls = segment_list.findall("SegmentURL")
|
||||||
|
for segment_url in segment_urls:
|
||||||
|
media_url = segment_url.get("media")
|
||||||
|
if media_url is None:
|
||||||
|
media_url = base_media_url
|
||||||
|
|
||||||
|
segments.append((
|
||||||
|
media_url,
|
||||||
|
segment_url.get("mediaRange")
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
log.error("Could not find a way to get segments from this MPD manifest.")
|
||||||
|
log.debug(manifest_url)
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
if not drm and isinstance(track, (Video, Audio)):
|
||||||
try:
|
try:
|
||||||
drm = Widevine.from_init_data(init_data)
|
drm = Widevine.from_init_data(init_data)
|
||||||
except Widevine.Exceptions.PSSHNotFound:
|
except Widevine.Exceptions.PSSHNotFound:
|
||||||
|
@ -468,20 +444,23 @@ class DASH:
|
||||||
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
|
||||||
license_widevine(drm)
|
license_widevine(drm)
|
||||||
|
|
||||||
for i, segment_url in enumerate(tqdm(segment_list.findall("SegmentURL"), unit="segments")):
|
state_event = Event()
|
||||||
segment_filename = str(i).zfill(len(str(len(segment_urls))))
|
|
||||||
segment_save_path = (save_dir / segment_filename).with_suffix(".mp4")
|
|
||||||
|
|
||||||
media_url = segment_url.get("media")
|
def download_segment(filename: str, segment: tuple[str, Optional[str]]):
|
||||||
if media_url is None:
|
time.sleep(0.1)
|
||||||
media_url = base_media_url
|
if state_event.is_set():
|
||||||
|
return
|
||||||
|
|
||||||
|
segment_save_path = (save_dir / filename).with_suffix(".mp4")
|
||||||
|
|
||||||
|
segment_uri, segment_range = segment
|
||||||
|
|
||||||
asyncio.run(aria2c(
|
asyncio.run(aria2c(
|
||||||
media_url,
|
segment_uri,
|
||||||
segment_save_path,
|
segment_save_path,
|
||||||
session.headers,
|
session.headers,
|
||||||
proxy,
|
proxy,
|
||||||
byte_range=segment_url.get("mediaRange")
|
byte_range=segment_range
|
||||||
))
|
))
|
||||||
|
|
||||||
if isinstance(track, Audio) or init_data:
|
if isinstance(track, Audio) or init_data:
|
||||||
|
@ -507,15 +486,34 @@ class DASH:
|
||||||
track.drm = None
|
track.drm = None
|
||||||
if callable(track.OnDecrypted):
|
if callable(track.OnDecrypted):
|
||||||
track.OnDecrypted(track)
|
track.OnDecrypted(track)
|
||||||
elif segment_base is not None or base_url:
|
|
||||||
# SegmentBase more or less boils down to defined ByteRanges
|
with tqdm(total=len(segments), unit="segments") as pbar:
|
||||||
# So, we don't care, just download the full file
|
with ThreadPoolExecutor(max_workers=16) as pool:
|
||||||
track.url = urljoin(period_base_url, base_url)
|
try:
|
||||||
track.descriptor = track.Descriptor.URL
|
for download in futures.as_completed((
|
||||||
track.drm = [drm] if drm else []
|
pool.submit(
|
||||||
else:
|
download_segment,
|
||||||
log.error("Could not find a way to get segments from this MPD manifest.")
|
filename=str(i).zfill(len(str(len(segments)))),
|
||||||
|
segment=segment
|
||||||
|
)
|
||||||
|
for i, segment in enumerate(segments)
|
||||||
|
)):
|
||||||
|
if download.cancelled():
|
||||||
|
continue
|
||||||
|
e = download.exception()
|
||||||
|
if e:
|
||||||
|
state_event.set()
|
||||||
|
pool.shutdown(wait=False, cancel_futures=True)
|
||||||
|
traceback.print_exception(e)
|
||||||
|
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
pbar.update(1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
state_event.set()
|
||||||
|
pool.shutdown(wait=False, cancel_futures=True)
|
||||||
|
log.info("Received Keyboard Interrupt, stopping...")
|
||||||
|
return
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_language(*options: Any) -> Optional[Language]:
|
def get_language(*options: Any) -> Optional[Language]:
|
||||||
|
|
Loading…
Reference in New Issue