Add support for the new Downloaders to DASH

rlaphoenix 2024-02-15 11:13:14 +00:00
parent 0e96d18af6
commit a1ed083b74
1 changed file with 30 additions and 122 deletions
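This commit replaces DASH's hand-rolled ThreadPoolExecutor segment loop, and the DASH.download_segment helper that drove it, with the project's new downloader functions, which manage their own worker pool and report progress as a stream of status dicts. The diff below shows only the call site. As a rough sketch of the generator interface it implies (inferred from this diff alone, so the name, defaults, and yielded keys are assumptions, not the actual devine.core.downloaders code), a compatible downloader could look like this:

    # A minimal sketch of the downloader generator interface this commit consumes.
    # Inferred from the call site below; the real implementations in
    # devine.core.downloaders (aria2c, curl_impersonate, requests) may differ.
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from pathlib import Path
    from typing import Any, Generator, Optional

    import requests


    def sketch_downloader(
        urls: list[dict[str, Any]],   # [{"url": ..., "headers": {...}}, ...]
        output_dir: Path,             # directory to place the files in
        filename: str,                # template, filled per item with {i}
        headers: Optional[dict] = None,
        cookies: Optional[Any] = None,
        proxy: Optional[str] = None,
        max_workers: int = 16
    ) -> Generator[dict[str, Any], None, None]:
        proxies = {"all": proxy} if proxy else None
        output_dir.mkdir(parents=True, exist_ok=True)

        def fetch(i: int, url_info: dict[str, Any]) -> Path:
            out_path = output_dir / filename.format(i=i)
            # per-URL headers (e.g. Range) override the session-wide ones
            merged_headers = {**(headers or {}), **url_info.get("headers", {})}
            res = requests.get(
                url_info["url"],
                headers=merged_headers,
                cookies=cookies,
                proxies=proxies
            )
            res.raise_for_status()
            out_path.write_bytes(res.content)
            return out_path

        with ThreadPoolExecutor(max_workers=max_workers) as pool:
            for job in as_completed(
                pool.submit(fetch, i, url_info)
                for i, url_info in enumerate(urls)
            ):
                # one "file_downloaded" event per finished file, then a progress tick
                yield {"file_downloaded": job.result()}
                yield {"advance": 1}

The new loop in the diff treats any status dict without a "file_downloaded" key as a Rich progress update, relabelling speed strings like "5.2 MB/s" as "DASH 5.2 MB/s".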


@@ -6,13 +6,10 @@ import logging
 import math
 import re
 import sys
-import time
-from concurrent import futures
-from concurrent.futures import ThreadPoolExecutor
 from copy import copy
 from functools import partial
 from pathlib import Path
-from typing import Any, Callable, MutableMapping, Optional, Union
+from typing import Any, Callable, Optional, Union
 from urllib.parse import urljoin, urlparse
 from uuid import UUID
 from zlib import crc32
@ -23,8 +20,6 @@ from lxml.etree import Element
from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH from pywidevine.pssh import PSSH
from requests import Session from requests import Session
from requests.cookies import RequestsCookieJar
from rich import filesize
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from devine.core.downloaders import downloader from devine.core.downloaders import downloader
@@ -435,57 +430,36 @@ class DASH:
         progress(total=len(segments))
 
-        download_sizes = []
-        download_speed_window = 5
-        last_speed_refresh = time.time()
-
-        with ThreadPoolExecutor(max_workers=16) as pool:
-            for i, download in enumerate(futures.as_completed((
-                pool.submit(
-                    DASH.download_segment,
-                    url=url,
-                    out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"),
-                    track=track,
-                    proxy=proxy,
-                    headers=session.headers,
-                    cookies=session.cookies,
-                    bytes_range=bytes_range
-                )
-                for n, (url, bytes_range) in enumerate(segments)
-            ))):
-                try:
-                    download_size = download.result()
-                except KeyboardInterrupt:
-                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                    progress(downloaded="[yellow]CANCELLING")
-                    pool.shutdown(wait=True, cancel_futures=True)
-                    progress(downloaded="[yellow]CANCELLED")
-                    # tell dl that it was cancelled
-                    # the pool is already shut down, so exiting loop is fine
-                    raise
-                except Exception:
-                    DOWNLOAD_CANCELLED.set()  # skip pending track downloads
-                    progress(downloaded="[red]FAILING")
-                    pool.shutdown(wait=True, cancel_futures=True)
-                    progress(downloaded="[red]FAILED")
-                    # tell dl that it failed
-                    # the pool is already shut down, so exiting loop is fine
-                    raise
-                else:
-                    progress(advance=1)
-                    now = time.time()
-                    time_since = now - last_speed_refresh
-                    if download_size:  # no size == skipped dl
-                        download_sizes.append(download_size)
-                    if download_sizes and (time_since > download_speed_window or i == len(segments)):
-                        data_size = sum(download_sizes)
-                        download_speed = data_size / (time_since or 1)
-                        progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
-                        last_speed_refresh = now
-                        download_sizes.clear()
+        downloader_ = downloader
+        if downloader.__name__ == "aria2c" and any(bytes_range is not None for url, bytes_range in segments):
+            # aria2(c) does not support the Range header, fall back to the requests downloader
+            downloader_ = requests_downloader
+
+        for status_update in downloader_(
+            urls=[
+                {
+                    "url": url,
+                    "headers": {
+                        "Range": f"bytes={bytes_range}"
+                    } if bytes_range else {}
+                }
+                for url, bytes_range in segments
+            ],
+            output_dir=save_dir,
+            filename="{i:0%d}.mp4" % (len(str(len(segments)))),
+            headers=session.headers,
+            cookies=session.cookies,
+            proxy=proxy,
+            max_workers=16
+        ):
+            file_downloaded = status_update.get("file_downloaded")
+            if file_downloaded and callable(track.OnSegmentDownloaded):
+                track.OnSegmentDownloaded(file_downloaded)
+            else:
+                downloaded = status_update.get("downloaded")
+                if downloaded and downloaded.endswith("/s"):
+                    status_update["downloaded"] = f"DASH {downloaded}"
+                progress(**status_update)
 
         with open(save_path, "wb") as f:
             if init_data:
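One subtlety in the new call: the filename argument is built in two stages. Old-style % formatting bakes the pad width into a new-style template, which the downloader then fills per segment via str.format, reproducing the zero-padded names the removed str(n).zfill(...) logic produced. A quick illustration with hypothetical values (assuming the downloader substitutes {i} with the segment index):

    segments = list(range(240))       # pretend manifest with 240 segments
    width = len(str(len(segments)))   # len("240") == 3
    template = "{i:0%d}.mp4" % width  # -> "{i:03}.mp4"
    print(template.format(i=7))       # -> "007.mp4"
    print(template.format(i=239))     # -> "239.mp4"

Fixed-width names keep the segment files in lexicographic order for the later concatenation into save_path.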
@@ -518,72 +492,6 @@ class DASH:
 
         progress(downloaded="Downloaded")
 
-    @staticmethod
-    def download_segment(
-        url: str,
-        out_path: Path,
-        track: AnyTrack,
-        proxy: Optional[str] = None,
-        headers: Optional[MutableMapping[str, str | bytes]] = None,
-        cookies: Optional[Union[MutableMapping[str, str], RequestsCookieJar]] = None,
-        bytes_range: Optional[str] = None
-    ) -> int:
-        """
-        Download a DASH Media Segment.
-
-        Parameters:
-            url: Full HTTP(S) URL to the Segment you want to download.
-            out_path: Path to save the downloaded Segment file to.
-            track: The Track object of which this Segment is for. Currently only used to
-                fix an invalid value in the TFHD box of Audio Tracks.
-            proxy: Proxy URI to use when downloading the Segment file.
-            headers: HTTP Headers to send when requesting the Segment file.
-            cookies: Cookies to send when requesting the Segment file. The actual cookies sent
-                will be resolved based on the URI among other parameters. Multiple cookies with
-                the same name but a different domain/path are resolved.
-            bytes_range: Download only specific bytes of the Segment file using the Range header.
-
-        Returns the file size of the downloaded Segment in bytes.
-        """
-        if DOWNLOAD_CANCELLED.is_set():
-            raise KeyboardInterrupt()
-
-        if bytes_range:
-            # aria2(c) doesn't support byte ranges, use python-requests
-            downloader_ = requests_downloader
-            headers_ = dict(**headers, Range=f"bytes={bytes_range}")
-        else:
-            downloader_ = downloader
-            headers_ = headers
-
-        downloader_(
-            uri=url,
-            out=out_path,
-            headers=headers_,
-            cookies=cookies,
-            proxy=proxy,
-            segmented=True
-        )
-
-        if callable(track.OnSegmentDownloaded):
-            track.OnSegmentDownloaded(out_path)
-
-        # fix audio decryption on ATVP by fixing the sample description index
-        # TODO: Should this be done in the video data or the init data?
-        if isinstance(track, Audio):
-            with open(out_path, "rb+") as f:
-                segment_data = f.read()
-                fixed_segment_data = re.sub(
-                    b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
-                    b"\\g<1>\x01",
-                    segment_data
-                )
-                if fixed_segment_data != segment_data:
-                    f.seek(0)
-                    f.write(fixed_segment_data)
-
-        return out_path.stat().st_size
-
     @staticmethod
     def _get(
         item: str,
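The deleted method also carried the ATVP audio workaround, which patches the sample description index in the tfhd box from 2 back to 1 so the segment decrypts correctly. The byte-level regex is dense, so here is a standalone demonstration of what it matches, run on a fabricated byte string rather than a real segment:

    import re

    # After "tfhd" come version (0x00), flags (0x02001A, which include
    # sample-description-index-present), track_ID (0x00000001), then the
    # 32-bit sample_description_index. The capture group covers everything
    # up to the index's low byte, which is rewritten from 0x02 to 0x01.
    segment_data = b"moof...tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00\x02..."
    fixed = re.sub(
        b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
        b"\\g<1>\x01",
        segment_data
    )
    assert fixed.endswith(b"\x00\x00\x00\x01...")  # index is now 1

Because the pattern hard-codes the flags and track_ID 1, it only fires on segments with that exact tfhd layout, which keeps the fix narrowly scoped to the affected ATVP audio tracks.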