fix(HLS): Use filtered out segment key info

Also simplifies calculation of wanted segment range when decrypting. Instead of storing the starting segment index number with the encryption_data variable, we just grab the first segment that isn't already merged.

Fixes #77
This commit is contained in:
rlaphoenix 2024-03-04 12:51:00 +00:00
parent 499fc67ea0
commit 6e8efc3f63
1 changed files with 71 additions and 58 deletions

View File

@ -239,19 +239,22 @@ class HLS:
else: else:
session_drm = None session_drm = None
segments = [ unwanted_segments = [
segment for segment in master.segments segment for segment in master.segments
if not callable(track.OnSegmentFilter) or not track.OnSegmentFilter(segment) if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment)
] ]
total_segments = len(segments) total_segments = len(master.segments) - len(unwanted_segments)
progress(total=total_segments) progress(total=total_segments)
downloader_ = downloader downloader_ = downloader
urls: list[dict[str, Any]] = [] urls: list[dict[str, Any]] = []
range_offset = 0 range_offset = 0
for segment in segments: for segment in master.segments:
if segment in unwanted_segments:
continue
if segment.byterange: if segment.byterange:
if downloader_.__name__ == "aria2c": if downloader_.__name__ == "aria2c":
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
@ -260,6 +263,7 @@ class HLS:
range_offset = byte_range.split("-")[0] range_offset = byte_range.split("-")[0]
else: else:
byte_range = None byte_range = None
urls.append({ urls.append({
"url": urljoin(segment.base_uri, segment.uri), "url": urljoin(segment.base_uri, segment.uri),
"headers": { "headers": {
@ -272,7 +276,7 @@ class HLS:
for status_update in downloader_( for status_update in downloader_(
urls=urls, urls=urls,
output_dir=segment_save_dir, output_dir=segment_save_dir,
filename="{i:0%d}{ext}" % len(str(len(segments))), filename="{i:0%d}{ext}" % len(str(len(urls))),
headers=session.headers, headers=session.headers,
cookies=session.cookies, cookies=session.cookies,
proxy=proxy, proxy=proxy,
@ -289,19 +293,21 @@ class HLS:
progress(total=total_segments, completed=0, downloaded="Merging") progress(total=total_segments, completed=0, downloaded="Merging")
name_len = len(str(total_segments))
discon_i = 0 discon_i = 0
range_offset = 0 range_offset = 0
map_data: Optional[tuple[m3u8.model.InitializationSection, bytes]] = None map_data: Optional[tuple[m3u8.model.InitializationSection, bytes]] = None
if session_drm: if session_drm:
encryption_data: Optional[tuple[int, Optional[m3u8.Key], DRM_T]] = (0, None, session_drm) encryption_data: Optional[tuple[Optional[m3u8.Key], DRM_T]] = (None, session_drm)
else: else:
encryption_data: Optional[tuple[int, Optional[m3u8.Key], DRM_T]] = None encryption_data: Optional[tuple[Optional[m3u8.Key], DRM_T]] = None
for i, segment in enumerate(segments): i = -1
is_last_segment = (i + 1) == total_segments for real_i, segment in enumerate(master.segments):
name_len = len(str(total_segments)) if segment not in unwanted_segments:
segment_file_ext = get_extension(segment.uri) i += 1
segment_file_path = segment_save_dir / f"{str(i).zfill(name_len)}{segment_file_ext}"
is_last_segment = (real_i + 1) == len(master.segments)
def merge(to: Path, via: list[Path], delete: bool = False, include_map_data: bool = False): def merge(to: Path, via: list[Path], delete: bool = False, include_map_data: bool = False):
""" """
@ -339,13 +345,17 @@ class HLS:
Returns the decrypted path. Returns the decrypted path.
""" """
drm = encryption_data[2] drm = encryption_data[1]
first_segment_i = encryption_data[0] first_segment_i = next(
int(file.stem)
for file in sorted(segment_save_dir.iterdir())
if file.stem.isdigit()
)
last_segment_i = max(0, i - int(not include_this_segment)) last_segment_i = max(0, i - int(not include_this_segment))
range_len = (last_segment_i - first_segment_i) + 1 range_len = (last_segment_i - first_segment_i) + 1
segment_range = f"{str(first_segment_i).zfill(name_len)}-{str(last_segment_i).zfill(name_len)}" segment_range = f"{str(first_segment_i).zfill(name_len)}-{str(last_segment_i).zfill(name_len)}"
merged_path = segment_save_dir / f"{segment_range}{get_extension(segments[last_segment_i].uri)}" merged_path = segment_save_dir / f"{segment_range}{get_extension(master.segments[last_segment_i].uri)}"
decrypted_path = segment_save_dir / f"{merged_path.stem}_decrypted{merged_path.suffix}" decrypted_path = segment_save_dir / f"{merged_path.stem}_decrypted{merged_path.suffix}"
files = [ files = [
@ -405,7 +415,10 @@ class HLS:
include_map_data=include_map_data include_map_data=include_map_data
) )
if segment not in unwanted_segments:
if isinstance(track, Subtitle): if isinstance(track, Subtitle):
segment_file_ext = get_extension(segment.uri)
segment_file_path = segment_save_dir / f"{str(i).zfill(name_len)}{segment_file_ext}"
segment_data = try_ensure_utf8(segment_file_path.read_bytes()) segment_data = try_ensure_utf8(segment_file_path.read_bytes())
if track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML): if track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML):
segment_data = segment_data.decode("utf8"). \ segment_data = segment_data.decode("utf8"). \
@ -419,14 +432,14 @@ class HLS:
decrypt(include_this_segment=False) decrypt(include_this_segment=False)
merge_discontinuity( merge_discontinuity(
include_this_segment=False, include_this_segment=False,
include_map_data=not encryption_data or not encryption_data[2] include_map_data=not encryption_data or not encryption_data[1]
) )
discon_i += 1 discon_i += 1
range_offset = 0 # TODO: Should this be reset or not? range_offset = 0 # TODO: Should this be reset or not?
map_data = None map_data = None
if encryption_data: if encryption_data:
encryption_data = (i, encryption_data[1], encryption_data[2]) encryption_data = (encryption_data[0], encryption_data[1])
if segment.init_section and (not map_data or segment.init_section != map_data[0]): if segment.init_section and (not map_data or segment.init_section != map_data[0]):
if segment.init_section.byterange: if segment.init_section.byterange:
@ -450,12 +463,12 @@ class HLS:
if segment.keys: if segment.keys:
key = HLS.get_supported_key(segment.keys) key = HLS.get_supported_key(segment.keys)
if encryption_data and encryption_data[1] != key and i != 0: if encryption_data and encryption_data[0] != key and i != 0 and segment not in unwanted_segments:
decrypt(include_this_segment=False) decrypt(include_this_segment=False)
if key is None: if key is None:
encryption_data = None encryption_data = None
elif not encryption_data or encryption_data[1] != key: elif not encryption_data or encryption_data[0] != key:
drm = HLS.get_drm(key, proxy) drm = HLS.get_drm(key, proxy)
if isinstance(drm, Widevine): if isinstance(drm, Widevine):
try: try:
@ -470,7 +483,7 @@ class HLS:
DOWNLOAD_CANCELLED.set() # skip pending track downloads DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[red]FAILED") progress(downloaded="[red]FAILED")
raise raise
encryption_data = (i, key, drm) encryption_data = (key, drm)
# TODO: This wont work as we already downloaded # TODO: This wont work as we already downloaded
if DOWNLOAD_LICENCE_ONLY.is_set(): if DOWNLOAD_LICENCE_ONLY.is_set():
@ -482,7 +495,7 @@ class HLS:
decrypt(include_this_segment=True) decrypt(include_this_segment=True)
merge_discontinuity( merge_discontinuity(
include_this_segment=True, include_this_segment=True,
include_map_data=not encryption_data or not encryption_data[2] include_map_data=not encryption_data or not encryption_data[1]
) )
progress(advance=1) progress(advance=1)