Rework the HLS downloader, add support for new downloaders

- It now downloads all segment files first, multi-threaded, before any decryption or merging operations (excluding init data, which is downloaded sequentially, in order, after all the segments are downloaded).
- Once all segments are downloaded, it then goes through them and performs any merging, decryption, init data handling, etc.
- Segments are no longer decrypted one by one. If segments use the same EXT-X-KEY data, they are merged together and then decrypted as a single file. This should give a noticeable speed increase for Widevine DRM.
This commit is contained in:
rlaphoenix 2024-02-15 11:15:20 +00:00
parent e5a330df7e
commit 2b7fc929f6
1 changed files with 307 additions and 289 deletions

View File

@ -2,17 +2,11 @@ from __future__ import annotations
import html
import logging
import re
import shutil
import subprocess
import sys
import time
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from pathlib import Path
from queue import Queue
from threading import Lock
from typing import Any, Callable, Optional, Union
from urllib.parse import urljoin
from zlib import crc32
@ -24,7 +18,6 @@ from m3u8 import M3U8
from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH
from requests import Session
from rich import filesize
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from devine.core.downloaders import downloader
@ -95,7 +88,7 @@ class HLS:
All Track objects' URL will be to another M3U(8) document. However, these documents
will be Invariant Playlists and contain the list of segments URIs among other metadata.
"""
session_drm = HLS.get_drm(self.manifest.session_keys)
session_drm = HLS.get_all_drm(self.manifest.session_keys)
audio_codecs_by_group_id: dict[str, Audio.Codec] = {}
tracks = Tracks()
@ -238,114 +231,225 @@ class HLS:
else:
session_drm = None
progress(total=len(master.segments))
segments = [
segment for segment in master.segments
if not callable(track.OnSegmentFilter) or not track.OnSegmentFilter(segment)
]
download_sizes = []
download_speed_window = 5
last_speed_refresh = time.time()
total_segments = len(segments)
progress(total=total_segments)
segment_key = Queue(maxsize=1)
segment_key.put((session_drm, None))
init_data = Queue(maxsize=1)
init_data.put(None)
range_offset = Queue(maxsize=1)
range_offset.put(0)
drm_lock = Lock()
downloader_ = downloader
discontinuities: list[list[segment]] = []
discontinuity_index = -1
for i, segment in enumerate(master.segments):
if i == 0 or segment.discontinuity:
discontinuity_index += 1
discontinuities.append([])
discontinuities[discontinuity_index].append(segment)
urls: list[dict[str, Any]] = []
range_offset = 0
for segment in segments:
if segment.byterange:
if downloader_.__name__ == "aria2c":
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
downloader_ = requests_downloader
byte_range = HLS.calculate_byte_range(segment.byterange, range_offset)
range_offset = byte_range.split("-")[0]
else:
byte_range = None
urls.append({
"url": urljoin(segment.base_uri, segment.uri),
"headers": {
"Range": f"bytes={byte_range}"
} if byte_range else {}
})
for d_i, discontinuity in enumerate(discontinuities):
# each discontinuity is a separate 'file'/encode and must be processed separately
discontinuity_save_dir = save_dir / str(d_i).zfill(len(str(len(discontinuities))))
discontinuity_save_path = discontinuity_save_dir.with_suffix(Path(discontinuity[0].uri).suffix)
segment_save_dir = save_dir / "segments"
with ThreadPoolExecutor(max_workers=16) as pool:
for i, download in enumerate(futures.as_completed((
pool.submit(
HLS.download_segment,
segment=segment,
out_path=(
discontinuity_save_dir /
str(s_i).zfill(len(str(len(discontinuity))))
).with_suffix(Path(segment.uri).suffix),
track=track,
init_data=init_data,
segment_key=segment_key,
range_offset=range_offset,
drm_lock=drm_lock,
progress=progress,
license_widevine=license_widevine,
session=session,
proxy=proxy
for status_update in downloader_(
urls=urls,
output_dir=segment_save_dir,
filename="{i:0%d}{ext}" % len(str(len(segments))),
headers=session.headers,
cookies=session.cookies,
proxy=proxy,
max_workers=16
):
file_downloaded = status_update.get("file_downloaded")
if file_downloaded and callable(track.OnSegmentDownloaded):
track.OnSegmentDownloaded(file_downloaded)
else:
downloaded = status_update.get("downloaded")
if downloaded and downloaded.endswith("/s"):
status_update["downloaded"] = f"HLS {downloaded}"
progress(**status_update)
progress(total=total_segments, completed=0, downloaded="Merging")
discon_i = 0
range_offset = 0
map_data: Optional[tuple[m3u8.model.InitializationSection, bytes]] = None
if session_drm:
encryption_data: Optional[tuple[int, Optional[m3u8.Key], DRM_T]] = (0, None, session_drm)
else:
encryption_data: Optional[tuple[int, Optional[m3u8.Key], DRM_T]] = None
for i, segment in enumerate(segments):
is_last_segment = (i + 1) == total_segments
name_len = len(str(total_segments))
segment_file_ext = Path(segment.uri).suffix
segment_file_path = segment_save_dir / f"{str(i).zfill(name_len)}{segment_file_ext}"
def merge(to: Path, via: list[Path], delete: bool = False, include_map_data: bool = False):
    """
    Concatenate a sequence of files into a single output file.

    Parameters:
        to: Destination file that receives the concatenated data.
        via: Source files, written out in the order given.
        delete: Remove each source file right after it has been written.
        include_map_data: Write the current init map data (if there is any)
            ahead of the source files.
    """
    with open(to, "wb") as merged_file:
        # the init map data, when requested and available, always goes first
        if include_map_data and map_data:
            merged_file.write(map_data[1])
        for part in via:
            merged_file.write(part.read_bytes())
            if delete:
                part.unlink()
def decrypt(include_this_segment: bool) -> Path:
    """
    Merge and decrypt every segment covered by the currently set DRM.

    The segment files in range are concatenated in sequence (with the init
    data first, if there is any) and deleted. The merged file is then
    decrypted and renamed. The merged and decrypted file names state the
    range of segment numbers that went into them.

    Parameters:
        include_this_segment: Whether the current segment is part of the range
            to merge and decrypt. Pass False when decrypting because the
            EXT-X-KEY changed, or True when decrypting on the last segment.

    Returns the decrypted path.
    """
    first_segment_i, _, drm = encryption_data
    last_segment_i = i if include_this_segment else i - 1

    first_label = str(first_segment_i).zfill(name_len)
    last_label = str(last_segment_i).zfill(name_len)
    range_suffix = Path(segments[last_segment_i].uri).suffix
    merged_path = segment_save_dir / f"{first_label}-{last_label}{range_suffix}"
    decrypted_path = segment_save_dir / f"{merged_path.stem}_decrypted{merged_path.suffix}"

    # only plain numbered segment files take part; previously merged or
    # decrypted outputs have non-numeric stems and are left alone
    in_range = []
    for candidate in sorted(segment_save_dir.iterdir()):
        if not candidate.stem.isdigit():
            continue
        if first_segment_i <= int(candidate.stem) <= last_segment_i:
            in_range.append(candidate)

    merge(
        to=merged_path,
        via=in_range,
        delete=True,
        include_map_data=True
    )

    drm.decrypt(merged_path)
    merged_path.rename(decrypted_path)

    if callable(track.OnDecrypted):
        track.OnDecrypted(drm, decrypted_path)

    return decrypted_path
def merge_discontinuity():
    """
    Merge every file currently in the segment save directory into one file.

    The output is written next to the segment save directory, named after the
    current discontinuity index, and the (then empty) directory is removed.
    """
    # NOTE(review): this picks up *every* file in the directory. The segments
    # all appear to be downloaded before merging begins, so confirm this cannot
    # pull in segments that belong to a later discontinuity.
    pending = sorted(segment_save_dir.iterdir())
    target = segment_save_dir.parent / f"{str(discon_i).zfill(name_len)}{pending[-1].suffix}"
    # NOTE(review): files produced by decrypt() were already merged with the map
    # data in front; confirm that prepending it again here is intended.
    merge(
        to=target,
        via=pending,
        delete=True,
        include_map_data=True
    )
    segment_save_dir.rmdir()
if isinstance(track, Subtitle):
segment_data = try_ensure_utf8(segment_file_path.read_bytes())
if track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML):
# decode text direction entities or SubtitleEdit's /ReverseRtlStartEnd won't work
segment_data = segment_data.decode("utf8"). \
replace("&lrm;", html.unescape("&lrm;")). \
replace("&rlm;", html.unescape("&rlm;")). \
encode("utf8")
segment_file_path.write_bytes(segment_data)
if segment.discontinuity:
if encryption_data:
decrypt(include_this_segment=False)
merge_discontinuity()
discon_i += 1
range_offset = 0 # TODO: Should this be reset or not?
map_data = None
encryption_data = None # TODO: Should this be reset or not?
if segment.init_section and (not map_data or segment.init_section != map_data[0]):
if segment.init_section.byterange:
init_byte_range = HLS.calculate_byte_range(
segment.init_section.byterange,
range_offset
)
for s_i, segment in enumerate(discontinuity)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[yellow]CANCELLING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]CANCELLED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise e
else:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
range_offset = init_byte_range.split("-")[0]
init_range_header = {
"Range": f"bytes={init_byte_range}"
}
else:
init_range_header = {}
if download_size == -1: # skipped for --skip-dl
progress(downloaded="[yellow]SKIPPING")
continue
res = session.get(
url=urljoin(segment.init_section.base_uri, segment.init_section.uri),
headers=init_range_header
)
res.raise_for_status()
map_data = (segment.init_section, res.content)
now = time.time()
time_since = now - last_speed_refresh
if segment.keys:
key = HLS.get_supported_key(segment.keys)
if encryption_data and encryption_data[1] != key:
decrypt(include_this_segment=False)
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if key is None:
encryption_data = None
elif not encryption_data or encryption_data[1] != key:
drm = HLS.get_drm(key, proxy)
if isinstance(drm, Widevine):
if map_data:
track_kid = track.get_key_id(map_data[1])
else:
track_kid = None
progress(downloaded="LICENSING")
license_widevine(drm, track_kid=track_kid)
progress(downloaded="[yellow]LICENSED")
encryption_data = (i, key, drm)
if download_sizes and (time_since > download_speed_window or i == len(master.segments)):
data_size = sum(download_sizes)
download_speed = data_size / (time_since or 1)
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
# TODO: This won't work as the segments have already been downloaded by this point
if DOWNLOAD_LICENCE_ONLY.is_set():
continue
if discontinuity_save_dir.exists():
with open(discontinuity_save_path, "wb") as f:
for segment_file in sorted(discontinuity_save_dir.iterdir()):
segment_data = segment_file.read_bytes()
if isinstance(track, Subtitle):
segment_data = try_ensure_utf8(segment_data)
if track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML):
# decode text direction entities or SubtitleEdit's /ReverseRtlStartEnd won't work
segment_data = segment_data.decode("utf8"). \
replace("&lrm;", html.unescape("&lrm;")). \
replace("&rlm;", html.unescape("&rlm;")). \
encode("utf8")
f.write(segment_data)
segment_file.unlink()
shutil.rmtree(discontinuity_save_dir)
if is_last_segment:
# required as it won't end with EXT-X-DISCONTINUITY nor a new key
if encryption_data:
decrypt(include_this_segment=True)
merge_discontinuity()
progress(advance=1)
# TODO: Again, this still won't work as the segments have already been downloaded
if DOWNLOAD_LICENCE_ONLY.is_set():
return
# finally merge all the discontinuity save files together to the final path
progress(downloaded="Merging")
if isinstance(track, (Video, Audio)):
progress(downloaded="Merging")
HLS.merge_segments(
segments=sorted(list(save_dir.iterdir())),
save_path=save_path
@ -365,162 +469,6 @@ class HLS:
if callable(track.OnDownloaded):
track.OnDownloaded()
@staticmethod
def download_segment(
segment: m3u8.Segment,
out_path: Path,
track: AnyTrack,
init_data: Queue,
segment_key: Queue,
range_offset: Queue,
drm_lock: Lock,
progress: partial,
license_widevine: Optional[Callable] = None,
session: Optional[Session] = None,
proxy: Optional[str] = None
) -> int:
"""
Download (and Decrypt) an HLS Media Segment.

Note: Make sure all Queue objects passed are appropriately initialized with
a starting value or this function may get permanently stuck.

Parameters:
segment: The m3u8.Segment Object to Download.
out_path: Path to save the downloaded Segment file to.
track: The Track object of which this Segment is for. Currently used to fix an
invalid value in the TFHD box of Audio Tracks, for the OnSegmentFilter, and
for DRM-related operations like getting the Track ID and Decryption.
init_data: Queue for saving and loading the most recent init section data.
segment_key: Queue for saving and loading the most recent DRM object, and its
adjacent Segment.Key object.
range_offset: Queue for saving and loading the most recent Segment Bytes Range.
drm_lock: Prevent more than one Download from doing anything DRM-related at the
same time. Make sure all calls to download_segment() use the same Lock object.
progress: Rich Progress bar to provide progress updates to.
license_widevine: Function used to license Widevine DRM objects. It must be passed
if the Segment's DRM uses Widevine.
proxy: Proxy URI to use when downloading the Segment file.
session: Python-Requests Session used when requesting init data. Its headers and
cookies are also passed on to the downloader.

Returns the file size of the downloaded Segment in bytes, 0 if the Segment was
filtered out by the track's OnSegmentFilter, or -1 if DOWNLOAD_LICENCE_ONLY is set.

Raises KeyboardInterrupt if DOWNLOAD_CANCELLED is set, and ValueError if the
Segment uses Widevine DRM but no license_widevine function was supplied.
"""
# bail out straight away if a cancellation has been signalled
if DOWNLOAD_CANCELLED.is_set():
raise KeyboardInterrupt()
# filtered segments are not downloaded and report a size of 0
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return 0
# handle init section changes
# the value is taken out of the queue here; the finally below always puts one back
newest_init_data = init_data.get()
try:
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
# be unnecessary and slow to re-download the init data each time.
if segment.init_section.byterange:
previous_range_offset = range_offset.get()
byte_range = HLS.calculate_byte_range(segment.init_section.byterange, previous_range_offset)
# the start of this byte range is stored as the next fallback offset
range_offset.put(byte_range.split("-")[0])
range_header = {
"Range": f"bytes={byte_range}"
}
else:
range_header = {}
res = session.get(
url=urljoin(segment.init_section.base_uri, segment.init_section.uri),
headers=range_header
)
res.raise_for_status()
newest_init_data = res.content
finally:
init_data.put(newest_init_data)
# handle segment key changes
with drm_lock:
newest_segment_key = segment_key.get()
try:
# only build new DRM objects when this segment's keys differ from the stored ones
if segment.keys and newest_segment_key[1] != segment.keys:
drm = HLS.get_drm(
keys=segment.keys,
proxy=proxy
)
if drm:
track.drm = drm
# license and grab content keys
# TODO: What if we don't want to use the first DRM system?
drm = drm[0]
if isinstance(drm, Widevine):
track_kid = track.get_key_id(newest_init_data)
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
progress(downloaded="LICENSING")
license_widevine(drm, track_kid=track_kid)
progress(downloaded="[yellow]LICENSED")
newest_segment_key = (drm, segment.keys)
finally:
segment_key.put(newest_segment_key)
# licence-only mode stops here, before any segment data is downloaded
if DOWNLOAD_LICENCE_ONLY.is_set():
return -1
# NOTE(review): no copy is made here, so headers_ appears to be the very same
# object as session.headers; the "Range" assignment below would then also change
# the session's headers for later requests - confirm this is intended.
headers_ = session.headers
if segment.byterange:
# aria2(c) doesn't support byte ranges, use python-requests
downloader_ = requests_downloader
previous_range_offset = range_offset.get()
byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
range_offset.put(byte_range.split("-")[0])
headers_["Range"] = f"bytes={byte_range}"
else:
downloader_ = downloader
downloader_(
uri=urljoin(segment.base_uri, segment.uri),
out=out_path,
headers=headers_,
cookies=session.cookies,
proxy=proxy,
segmented=True
)
if callable(track.OnSegmentDownloaded):
track.OnSegmentDownloaded(out_path)
# the size is measured before the file is modified by any of the steps below
download_size = out_path.stat().st_size
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Should this be done in the video data or the init data?
if isinstance(track, Audio):
with open(out_path, "rb+") as f:
segment_data = f.read()
fixed_segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# the file is only rewritten if the substitution changed something
if fixed_segment_data != segment_data:
f.seek(0)
f.write(fixed_segment_data)
# prepend the init data to be able to decrypt
if newest_init_data:
with open(out_path, "rb+") as f:
segment_data = f.read()
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
# decrypt segment if encrypted
if newest_segment_key[0]:
newest_segment_key[0].decrypt(out_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(newest_segment_key[0], segment)
return download_size
@staticmethod
def merge_segments(segments: list[Path], save_path: Path) -> int:
"""
@ -552,53 +500,123 @@ class HLS:
return save_path.stat().st_size
@staticmethod
def get_supported_key(keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]]) -> Optional[m3u8.Key]:
    """
    Get a supported Key System from a list of Key systems.

    Note that the key systems are chosen in an opinionated order: the first
    supported entry in the list is the one that gets returned.

    Parameters:
        keys: m3u8 key system (EXT-X-KEY) objects. Blank (None) entries are ignored.

    Returns None if one of the key systems is method=NONE, which means all segments
    from then on should be treated as plain text until another key system is
    encountered, unless it's also method=NONE. None is also returned if the list
    holds no key data at all (empty, or only blank entries).

    Raises NotImplementedError if none of the key systems are supported.
    """
    # blank entries carry no key data; they must not be dereferenced
    if any(key and key.method == "NONE" for key in keys):
        return None

    unsupported_systems = []
    for key in keys:
        if not key:
            continue
        # TODO: Add a way to specify which supported key system to use
        # TODO: Add support for 'SAMPLE-AES', 'AES-CTR', 'AES-CBC', 'ClearKey'
        if key.method in ("AES-128", "ISO-23001-7"):
            return key
        if key.keyformat and key.keyformat.lower() == WidevineCdm.urn:
            return key
        unsupported_systems.append(key.method + (f" ({key.keyformat})" if key.keyformat else ""))

    if unsupported_systems:
        raise NotImplementedError(f"None of the key systems are supported: {', '.join(unsupported_systems)}")

    # nothing but blank entries: there is no key system to apply
    return None
@staticmethod
def get_drm(
    key: Union[m3u8.model.SessionKey, m3u8.model.Key],
    proxy: Optional[str] = None
) -> DRM_T:
    """
    Build an initialized DRM object from a single HLS EXT-X-KEY entry.

    Parameters:
        key: m3u8 key system (EXT-X-KEY) object.
        proxy: Optional proxy string used for requesting AES-128 URIs.

    Raises a NotImplementedError if the key system is not supported.
    """
    # TODO: Add support for 'SAMPLE-AES', 'AES-CTR', 'AES-CBC', 'ClearKey'
    method = key.method

    if method == "AES-128":
        # TODO: Use a session instead of creating a new connection within
        return ClearKey.from_m3u_key(key, proxy)

    if method == "ISO-23001-7":
        # the part of the URI after the last comma is used as the key ID
        key_id = key.uri.split(",")[-1]
        return Widevine(
            pssh=PSSH.new(
                key_ids=[key_id],
                system_id=PSSH.SystemId.Widevine
            )
        )

    if key.keyformat and key.keyformat.lower() == WidevineCdm.urn:
        # the part of the URI after the last comma is passed to PSSH as-is
        pssh_data = key.uri.split(",")[-1]
        return Widevine(
            pssh=PSSH(pssh_data),
            **key._extra_params  # noqa
        )

    raise NotImplementedError(f"The key system is not supported: {key}")
@staticmethod
def get_all_drm(
    keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],
    proxy: Optional[str] = None
) -> list[DRM_T]:
    """
    Convert HLS EXT-X-KEY data to initialized DRM objects.

    You can supply key data for a single segment or for the entire manifest.
    This lets you narrow the results down to each specific segment's DRM status.

    Parameters:
        keys: m3u8 key system (EXT-X-KEY) objects. Blank (None) entries are ignored.
        proxy: Optional proxy string used for requesting AES-128 URIs.

    Returns an empty list if there were no supplied EXT-X-KEY data, if all the
    EXT-X-KEY's were of blank data, or if any of them is method=NONE. An empty
    list signals a DRM-free stream or segment.

    Raises a NotImplementedError if EXT-X-KEY data was supplied and none of the
    key systems are supported. A DRM-free track will never raise NotImplementedError.
    """
    # blank entries carry no key data; drop them before touching any attribute
    present_keys = [key for key in keys if key]

    if any(key.method == "NONE" for key in present_keys):
        return []

    unsupported_keys: list[m3u8.Key] = []
    drm_objects: list[DRM_T] = []

    for key in present_keys:
        try:
            drm_objects.append(HLS.get_drm(key, proxy))
        except NotImplementedError:
            # remember it; this only becomes an error if nothing else is supported
            unsupported_keys.append(key)

    if not drm_objects and unsupported_keys:
        raise NotImplementedError(f"None of the key systems are supported: {unsupported_keys}")

    return drm_objects
@staticmethod
def calculate_byte_range(m3u_range: str, fallback_offset: int = 0) -> str: