diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py index 6c1fe86..f9de82b 100644 --- a/devine/core/manifests/hls.py +++ b/devine/core/manifests/hls.py @@ -2,17 +2,11 @@ from __future__ import annotations import html import logging -import re import shutil import subprocess import sys -import time -from concurrent import futures -from concurrent.futures import ThreadPoolExecutor from functools import partial from pathlib import Path -from queue import Queue -from threading import Lock from typing import Any, Callable, Optional, Union from urllib.parse import urljoin from zlib import crc32 @@ -24,7 +18,6 @@ from m3u8 import M3U8 from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.pssh import PSSH from requests import Session -from rich import filesize from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack from devine.core.downloaders import downloader @@ -95,7 +88,7 @@ class HLS: All Track objects' URL will be to another M3U(8) document. However, these documents will be Invariant Playlists and contain the list of segments URIs among other metadata. """ - session_drm = HLS.get_drm(self.manifest.session_keys) + session_drm = HLS.get_all_drm(self.manifest.session_keys) audio_codecs_by_group_id: dict[str, Audio.Codec] = {} tracks = Tracks() @@ -238,114 +231,225 @@ class HLS: else: session_drm = None - progress(total=len(master.segments)) + segments = [ + segment for segment in master.segments + if not callable(track.OnSegmentFilter) or not track.OnSegmentFilter(segment) + ] - download_sizes = [] - download_speed_window = 5 - last_speed_refresh = time.time() + total_segments = len(segments) + progress(total=total_segments) - segment_key = Queue(maxsize=1) - segment_key.put((session_drm, None)) - init_data = Queue(maxsize=1) - init_data.put(None) - range_offset = Queue(maxsize=1) - range_offset.put(0) - drm_lock = Lock() + downloader_ = downloader - discontinuities: list[list[segment]] = [] - discontinuity_index = -1 - for i, segment in enumerate(master.segments): - if i == 0 or segment.discontinuity: - discontinuity_index += 1 - discontinuities.append([]) - discontinuities[discontinuity_index].append(segment) + urls: list[dict[str, Any]] = [] + range_offset = 0 + for segment in segments: + if segment.byterange: + if downloader_.__name__ == "aria2c": + # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader + downloader_ = requests_downloader + byte_range = HLS.calculate_byte_range(segment.byterange, range_offset) + range_offset = byte_range.split("-")[0] + else: + byte_range = None + urls.append({ + "url": urljoin(segment.base_uri, segment.uri), + "headers": { + "Range": f"bytes={byte_range}" + } if byte_range else {} + }) - for d_i, discontinuity in enumerate(discontinuities): - # each discontinuity is a separate 'file'/encode and must be processed separately - discontinuity_save_dir = save_dir / str(d_i).zfill(len(str(len(discontinuities)))) - discontinuity_save_path = discontinuity_save_dir.with_suffix(Path(discontinuity[0].uri).suffix) + segment_save_dir = save_dir / "segments" - with ThreadPoolExecutor(max_workers=16) as pool: - for i, download in enumerate(futures.as_completed(( - pool.submit( - HLS.download_segment, - segment=segment, - out_path=( - discontinuity_save_dir / - str(s_i).zfill(len(str(len(discontinuity)))) - ).with_suffix(Path(segment.uri).suffix), - track=track, - init_data=init_data, - segment_key=segment_key, - range_offset=range_offset, - drm_lock=drm_lock, - progress=progress, - license_widevine=license_widevine, - session=session, - proxy=proxy + for status_update in downloader_( + urls=urls, + output_dir=segment_save_dir, + filename="{i:0%d}{ext}" % len(str(len(segments))), + headers=session.headers, + cookies=session.cookies, + proxy=proxy, + max_workers=16 + ): + file_downloaded = status_update.get("file_downloaded") + if file_downloaded and callable(track.OnSegmentDownloaded): + track.OnSegmentDownloaded(file_downloaded) + else: + downloaded = status_update.get("downloaded") + if downloaded and downloaded.endswith("/s"): + status_update["downloaded"] = f"HLS {downloaded}" + progress(**status_update) + + progress(total=total_segments, completed=0, downloaded="Merging") + + discon_i = 0 + range_offset = 0 + map_data: Optional[tuple[m3u8.model.InitializationSection, bytes]] = None + if session_drm: + encryption_data: Optional[tuple[int, Optional[m3u8.Key], DRM_T]] = (0, None, session_drm) + else: + encryption_data: Optional[tuple[int, Optional[m3u8.Key], DRM_T]] = None + + for i, segment in enumerate(segments): + is_last_segment = (i + 1) == total_segments + name_len = len(str(total_segments)) + segment_file_ext = Path(segment.uri).suffix + segment_file_path = segment_save_dir / f"{str(i).zfill(name_len)}{segment_file_ext}" + + def merge(to: Path, via: list[Path], delete: bool = False, include_map_data: bool = False): + """ + Merge all files to a given path, optionally including map data. + + Parameters: + to: The output file with all merged data. + via: List of files to merge, in sequence. + delete: Delete the file once it's been merged. + include_map_data: Whether to include the init map data. + """ + with open(to, "wb") as x: + if include_map_data and map_data: + x.write(map_data[1]) + for file in via: + x.write(file.read_bytes()) + if delete: + file.unlink() + + def decrypt(include_this_segment: bool) -> Path: + """ + Decrypt all segments that uses the currently set DRM. + + All segments that will be decrypted with this DRM will be merged together + in sequence, prefixed with the init data (if any), and then deleted. Once + merged they will be decrypted. The merged and decrypted file names state + the range of segments that were used. + + Parameters: + include_this_segment: Whether to include the current segment in the + list of segments to merge and decrypt. This should be False if + decrypting on EXT-X-KEY changes, or True when decrypting on the + last segment. + + Returns the decrypted path. + """ + drm = encryption_data[2] + first_segment_i = encryption_data[0] + last_segment_i = i - int(not include_this_segment) + + segment_range = f"{str(first_segment_i).zfill(name_len)}-{str(last_segment_i).zfill(name_len)}" + merged_path = segment_save_dir / f"{segment_range}{Path(segments[last_segment_i].uri).suffix}" + decrypted_path = segment_save_dir / f"{merged_path.stem}_decrypted{merged_path.suffix}" + + merge( + to=merged_path, + via=[ + file + for file in sorted(segment_save_dir.iterdir()) + if file.stem.isdigit() and first_segment_i <= int(file.stem) <= last_segment_i + ], + delete=True, + include_map_data=True + ) + + drm.decrypt(merged_path) + merged_path.rename(decrypted_path) + + if callable(track.OnDecrypted): + track.OnDecrypted(drm, decrypted_path) + + return decrypted_path + + def merge_discontinuity(): + """Merge all files in the segment save directory so far.""" + files = list(sorted(segment_save_dir.iterdir())) + + to_dir = segment_save_dir.parent + to_path = to_dir / f"{str(discon_i).zfill(name_len)}{files[-1].suffix}" + + merge( + to=to_path, + via=files, + delete=True, + include_map_data=True + ) + segment_save_dir.rmdir() + + if isinstance(track, Subtitle): + segment_data = try_ensure_utf8(segment_file_path.read_bytes()) + if track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML): + # decode text direction entities or SubtitleEdit's /ReverseRtlStartEnd won't work + segment_data = segment_data.decode("utf8"). \ + replace("‎", html.unescape("‎")). \ + replace("‏", html.unescape("‏")). \ + encode("utf8") + segment_file_path.write_bytes(segment_data) + + if segment.discontinuity: + if encryption_data: + decrypt(include_this_segment=False) + merge_discontinuity() + + discon_i += 1 + range_offset = 0 # TODO: Should this be reset or not? + map_data = None + encryption_data = None # TODO: Should this be reset or not? + + if segment.init_section and (not map_data or segment.init_section != map_data[0]): + if segment.init_section.byterange: + init_byte_range = HLS.calculate_byte_range( + segment.init_section.byterange, + range_offset ) - for s_i, segment in enumerate(discontinuity) - ))): - try: - download_size = download.result() - except KeyboardInterrupt: - DOWNLOAD_CANCELLED.set() # skip pending track downloads - progress(downloaded="[yellow]CANCELLING") - pool.shutdown(wait=True, cancel_futures=True) - progress(downloaded="[yellow]CANCELLED") - # tell dl that it was cancelled - # the pool is already shut down, so exiting loop is fine - raise - except Exception as e: - DOWNLOAD_CANCELLED.set() # skip pending track downloads - progress(downloaded="[red]FAILING") - pool.shutdown(wait=True, cancel_futures=True) - progress(downloaded="[red]FAILED") - # tell dl that it failed - # the pool is already shut down, so exiting loop is fine - raise e - else: - # it successfully downloaded, and it was not cancelled - progress(advance=1) + range_offset = init_byte_range.split("-")[0] + init_range_header = { + "Range": f"bytes={init_byte_range}" + } + else: + init_range_header = {} - if download_size == -1: # skipped for --skip-dl - progress(downloaded="[yellow]SKIPPING") - continue + res = session.get( + url=urljoin(segment.init_section.base_uri, segment.init_section.uri), + headers=init_range_header + ) + res.raise_for_status() + map_data = (segment.init_section, res.content) - now = time.time() - time_since = now - last_speed_refresh + if segment.keys: + key = HLS.get_supported_key(segment.keys) + if encryption_data and encryption_data[1] != key: + decrypt(include_this_segment=False) - if download_size: # no size == skipped dl - download_sizes.append(download_size) + if key is None: + encryption_data = None + elif not encryption_data or encryption_data[1] != key: + drm = HLS.get_drm(key, proxy) + if isinstance(drm, Widevine): + if map_data: + track_kid = track.get_key_id(map_data[1]) + else: + track_kid = None + progress(downloaded="LICENSING") + license_widevine(drm, track_kid=track_kid) + progress(downloaded="[yellow]LICENSED") + encryption_data = (i, key, drm) - if download_sizes and (time_since > download_speed_window or i == len(master.segments)): - data_size = sum(download_sizes) - download_speed = data_size / (time_since or 1) - progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s") - last_speed_refresh = now - download_sizes.clear() + # TODO: This wont work as we already downloaded + if DOWNLOAD_LICENCE_ONLY.is_set(): + continue - if discontinuity_save_dir.exists(): - with open(discontinuity_save_path, "wb") as f: - for segment_file in sorted(discontinuity_save_dir.iterdir()): - segment_data = segment_file.read_bytes() - if isinstance(track, Subtitle): - segment_data = try_ensure_utf8(segment_data) - if track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML): - # decode text direction entities or SubtitleEdit's /ReverseRtlStartEnd won't work - segment_data = segment_data.decode("utf8"). \ - replace("‎", html.unescape("‎")). \ - replace("‏", html.unescape("‏")). \ - encode("utf8") - f.write(segment_data) - segment_file.unlink() - shutil.rmtree(discontinuity_save_dir) + if is_last_segment: + # required as it won't end with EXT-X-DISCONTINUITY nor a new key + if encryption_data: + decrypt(include_this_segment=True) + merge_discontinuity() + progress(advance=1) + + # TODO: Again still wont work, we've already downloaded if DOWNLOAD_LICENCE_ONLY.is_set(): return + # finally merge all the discontinuity save files together to the final path + progress(downloaded="Merging") if isinstance(track, (Video, Audio)): - progress(downloaded="Merging") HLS.merge_segments( segments=sorted(list(save_dir.iterdir())), save_path=save_path @@ -365,162 +469,6 @@ class HLS: if callable(track.OnDownloaded): track.OnDownloaded() - @staticmethod - def download_segment( - segment: m3u8.Segment, - out_path: Path, - track: AnyTrack, - init_data: Queue, - segment_key: Queue, - range_offset: Queue, - drm_lock: Lock, - progress: partial, - license_widevine: Optional[Callable] = None, - session: Optional[Session] = None, - proxy: Optional[str] = None - ) -> int: - """ - Download (and Decrypt) an HLS Media Segment. - - Note: Make sure all Queue objects passed are appropriately initialized with - a starting value or this function may get permanently stuck. - - Parameters: - segment: The m3u8.Segment Object to Download. - out_path: Path to save the downloaded Segment file to. - track: The Track object of which this Segment is for. Currently used to fix an - invalid value in the TFHD box of Audio Tracks, for the OnSegmentFilter, and - for DRM-related operations like getting the Track ID and Decryption. - init_data: Queue for saving and loading the most recent init section data. - segment_key: Queue for saving and loading the most recent DRM object, and it's - adjacent Segment.Key object. - range_offset: Queue for saving and loading the most recent Segment Bytes Range. - drm_lock: Prevent more than one Download from doing anything DRM-related at the - same time. Make sure all calls to download_segment() use the same Lock object. - progress: Rich Progress bar to provide progress updates to. - license_widevine: Function used to license Widevine DRM objects. It must be passed - if the Segment's DRM uses Widevine. - proxy: Proxy URI to use when downloading the Segment file. - session: Python-Requests Session used when requesting init data. - - Returns the file size of the downloaded Segment in bytes. - """ - if DOWNLOAD_CANCELLED.is_set(): - raise KeyboardInterrupt() - - if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment): - return 0 - - # handle init section changes - newest_init_data = init_data.get() - try: - if segment.init_section and (not newest_init_data or segment.discontinuity): - # Only use the init data if there's no init data yet (e.g., start of file) - # or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP. - # Even if a new EXT-X-MAP is supplied, it may just be duplicate and would - # be unnecessary and slow to re-download the init data each time. - if segment.init_section.byterange: - previous_range_offset = range_offset.get() - byte_range = HLS.calculate_byte_range(segment.init_section.byterange, previous_range_offset) - range_offset.put(byte_range.split("-")[0]) - range_header = { - "Range": f"bytes={byte_range}" - } - else: - range_header = {} - res = session.get( - url=urljoin(segment.init_section.base_uri, segment.init_section.uri), - headers=range_header - ) - res.raise_for_status() - newest_init_data = res.content - finally: - init_data.put(newest_init_data) - - # handle segment key changes - with drm_lock: - newest_segment_key = segment_key.get() - try: - if segment.keys and newest_segment_key[1] != segment.keys: - drm = HLS.get_drm( - keys=segment.keys, - proxy=proxy - ) - if drm: - track.drm = drm - # license and grab content keys - # TODO: What if we don't want to use the first DRM system? - drm = drm[0] - if isinstance(drm, Widevine): - track_kid = track.get_key_id(newest_init_data) - if not license_widevine: - raise ValueError("license_widevine func must be supplied to use Widevine DRM") - progress(downloaded="LICENSING") - license_widevine(drm, track_kid=track_kid) - progress(downloaded="[yellow]LICENSED") - newest_segment_key = (drm, segment.keys) - finally: - segment_key.put(newest_segment_key) - - if DOWNLOAD_LICENCE_ONLY.is_set(): - return -1 - - headers_ = session.headers - if segment.byterange: - # aria2(c) doesn't support byte ranges, use python-requests - downloader_ = requests_downloader - previous_range_offset = range_offset.get() - byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset) - range_offset.put(byte_range.split("-")[0]) - headers_["Range"] = f"bytes={byte_range}" - else: - downloader_ = downloader - - downloader_( - uri=urljoin(segment.base_uri, segment.uri), - out=out_path, - headers=headers_, - cookies=session.cookies, - proxy=proxy, - segmented=True - ) - - if callable(track.OnSegmentDownloaded): - track.OnSegmentDownloaded(out_path) - - download_size = out_path.stat().st_size - - # fix audio decryption on ATVP by fixing the sample description index - # TODO: Should this be done in the video data or the init data? - if isinstance(track, Audio): - with open(out_path, "rb+") as f: - segment_data = f.read() - fixed_segment_data = re.sub( - b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02", - b"\\g<1>\x01", - segment_data - ) - if fixed_segment_data != segment_data: - f.seek(0) - f.write(fixed_segment_data) - - # prepend the init data to be able to decrypt - if newest_init_data: - with open(out_path, "rb+") as f: - segment_data = f.read() - f.seek(0) - f.write(newest_init_data) - f.write(segment_data) - - # decrypt segment if encrypted - if newest_segment_key[0]: - newest_segment_key[0].decrypt(out_path) - track.drm = None - if callable(track.OnDecrypted): - track.OnDecrypted(newest_segment_key[0], segment) - - return download_size - @staticmethod def merge_segments(segments: list[Path], save_path: Path) -> int: """ @@ -552,53 +500,123 @@ class HLS: return save_path.stat().st_size + @staticmethod + def get_supported_key(keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]]) -> Optional[m3u8.Key]: + """ + Get a support Key System from a list of Key systems. + + Note that the key systems are chosen in an opinionated order. + + Returns None if one of the key systems is method=NONE, which means all segments + from hence forth should be treated as plain text until another key system is + encountered, unless it's also method=NONE. + + Raises NotImplementedError if none of the key systems are supported. + """ + if any(key.method == "NONE" for key in keys): + return None + + unsupported_systems = [] + for key in keys: + if not key: + continue + # TODO: Add a way to specify which supported key system to use + # TODO: Add support for 'SAMPLE-AES', 'AES-CTR', 'AES-CBC', 'ClearKey' + # if encryption_data and encryption_data[0] == key: + # # no need to re-obtain the exact same encryption data + # break + elif key.method == "AES-128": + return key + # # TODO: Use a session instead of creating a new connection within + # encryption_data = (key, ClearKey.from_m3u_key(key, proxy)) + # break + elif key.method == "ISO-23001-7": + return key + # encryption_data = (key, Widevine( + # pssh=PSSH.new( + # key_ids=[key.uri.split(",")[-1]], + # system_id=PSSH.SystemId.Widevine + # ) + # )) + # break + elif key.keyformat and key.keyformat.lower() == WidevineCdm.urn: + return key + # encryption_data = (key, Widevine( + # pssh=PSSH(key.uri.split(",")[-1]), + # **key._extra_params # noqa + # )) + # break + else: + unsupported_systems.append(key.method + (f" ({key.keyformat})" if key.keyformat else "")) + else: + raise NotImplementedError(f"None of the key systems are supported: {', '.join(unsupported_systems)}") + @staticmethod def get_drm( + key: Union[m3u8.model.SessionKey, m3u8.model.Key], + proxy: Optional[str] = None + ) -> DRM_T: + """ + Convert HLS EXT-X-KEY data to an initialized DRM object. + + Parameters: + key: m3u8 key system (EXT-X-KEY) object. + proxy: Optional proxy string used for requesting AES-128 URIs. + + Raises a NotImplementedError if the key system is not supported. + """ + # TODO: Add support for 'SAMPLE-AES', 'AES-CTR', 'AES-CBC', 'ClearKey' + if key.method == "AES-128": + # TODO: Use a session instead of creating a new connection within + drm = ClearKey.from_m3u_key(key, proxy) + elif key.method == "ISO-23001-7": + drm = Widevine( + pssh=PSSH.new( + key_ids=[key.uri.split(",")[-1]], + system_id=PSSH.SystemId.Widevine + ) + ) + elif key.keyformat and key.keyformat.lower() == WidevineCdm.urn: + drm = Widevine( + pssh=PSSH(key.uri.split(",")[-1]), + **key._extra_params # noqa + ) + else: + raise NotImplementedError(f"The key system is not supported: {key}") + + return drm + + @staticmethod + def get_all_drm( keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]], proxy: Optional[str] = None ) -> list[DRM_T]: """ Convert HLS EXT-X-KEY data to initialized DRM objects. - You can supply key data for a single segment or for the entire manifest. - This lets you narrow the results down to each specific segment's DRM status. + Parameters: + keys: m3u8 key system (EXT-X-KEY) objects. + proxy: Optional proxy string used for requesting AES-128 URIs. - Returns an empty list if there were no supplied EXT-X-KEY data, or if all the - EXT-X-KEY's were of blank data. An empty list signals a DRM-free stream or segment. - - Will raise a NotImplementedError if EXT-X-KEY data was supplied and none of them - were supported. A DRM-free track will never raise NotImplementedError. + Raises a NotImplementedError if none of the key systems are supported. """ - drm = [] - unsupported_systems = [] + unsupported_keys: list[m3u8.Key] = [] + drm_objects: list[DRM_T] = [] + + if any(key.method == "NONE" for key in keys): + return [] for key in keys: - if not key: - continue - # TODO: Add support for 'SAMPLE-AES', 'AES-CTR', 'AES-CBC', 'ClearKey' - if key.method == "NONE": - return [] - elif key.method == "AES-128": - drm.append(ClearKey.from_m3u_key(key, proxy)) - elif key.method == "ISO-23001-7": - drm.append(Widevine( - pssh=PSSH.new( - key_ids=[key.uri.split(",")[-1]], - system_id=PSSH.SystemId.Widevine - ) - )) - elif key.keyformat and key.keyformat.lower() == WidevineCdm.urn: - drm.append(Widevine( - pssh=PSSH(key.uri.split(",")[-1]), - **key._extra_params # noqa - )) - else: - unsupported_systems.append(key.method + (f" ({key.keyformat})" if key.keyformat else "")) + try: + drm = HLS.get_drm(key, proxy) + drm_objects.append(drm) + except NotImplementedError: + unsupported_keys.append(key) - if not drm and unsupported_systems: - raise NotImplementedError(f"No support for any of the key systems: {', '.join(unsupported_systems)}") + if not drm_objects and unsupported_keys: + raise NotImplementedError(f"None of the key systems are supported: {unsupported_keys}") - return drm + return drm_objects @staticmethod def calculate_byte_range(m3u_range: str, fallback_offset: int = 0) -> str: