Merge pull request #67 from Shivelight/feature/fix-webvtt-timestamp

Correct timestamps when merging fragmented WebVTT
This commit is contained in:
rlaphoenix 2024-05-07 06:54:42 +01:00 committed by GitHub
commit 7aa797a4cc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 255 additions and 15 deletions

View File

@ -285,12 +285,15 @@ class DASH:
segment_base = adaptation_set.find("SegmentBase") segment_base = adaptation_set.find("SegmentBase")
segments: list[tuple[str, Optional[str]]] = [] segments: list[tuple[str, Optional[str]]] = []
segment_timescale: float = 0
segment_durations: list[int] = []
track_kid: Optional[UUID] = None track_kid: Optional[UUID] = None
if segment_template is not None: if segment_template is not None:
segment_template = copy(segment_template) segment_template = copy(segment_template)
start_number = int(segment_template.get("startNumber") or 1) start_number = int(segment_template.get("startNumber") or 1)
segment_timeline = segment_template.find("SegmentTimeline") segment_timeline = segment_template.find("SegmentTimeline")
segment_timescale = float(segment_template.get("timescale") or 1)
for item in ("initialization", "media"): for item in ("initialization", "media"):
value = segment_template.get(item) value = segment_template.get(item)
@ -318,17 +321,16 @@ class DASH:
track_kid = track.get_key_id(init_data) track_kid = track.get_key_id(init_data)
if segment_timeline is not None: if segment_timeline is not None:
seg_time_list = []
current_time = 0 current_time = 0
for s in segment_timeline.findall("S"): for s in segment_timeline.findall("S"):
if s.get("t"): if s.get("t"):
current_time = int(s.get("t")) current_time = int(s.get("t"))
for _ in range(1 + (int(s.get("r") or 0))): for _ in range(1 + (int(s.get("r") or 0))):
seg_time_list.append(current_time) segment_durations.append(current_time)
current_time += int(s.get("d")) current_time += int(s.get("d"))
seg_num_list = list(range(start_number, len(seg_time_list) + start_number)) seg_num_list = list(range(start_number, len(segment_durations) + start_number))
for t, n in zip(seg_time_list, seg_num_list): for t, n in zip(segment_durations, seg_num_list):
segments.append(( segments.append((
DASH.replace_fields( DASH.replace_fields(
segment_template.get("media"), segment_template.get("media"),
@ -342,8 +344,7 @@ class DASH:
if not period_duration: if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.") raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration) period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration")) segment_duration = float(segment_template.get("duration")) or 1
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale)) total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
for s in range(start_number, start_number + total_segments): for s in range(start_number, start_number + total_segments):
@ -356,7 +357,11 @@ class DASH:
Time=s Time=s
), None ), None
)) ))
# TODO: Should we floor/ceil/round, or is int() ok?
segment_durations.append(int(segment_duration))
elif segment_list is not None: elif segment_list is not None:
segment_timescale = float(segment_list.get("timescale") or 1)
init_data = None init_data = None
initialization = segment_list.find("Initialization") initialization = segment_list.find("Initialization")
if initialization is not None: if initialization is not None:
@ -388,6 +393,7 @@ class DASH:
media_url, media_url,
segment_url.get("mediaRange") segment_url.get("mediaRange")
)) ))
segment_durations.append(int(segment_url.get("duration") or 1))
elif segment_base is not None: elif segment_base is not None:
media_range = None media_range = None
init_data = None init_data = None
@ -420,6 +426,10 @@ class DASH:
log.debug(track.url) log.debug(track.url)
sys.exit(1) sys.exit(1)
# TODO: Should we floor/ceil/round, or is int() ok?
track.data["dash"]["timescale"] = int(segment_timescale)
track.data["dash"]["segment_durations"] = segment_durations
if not track.drm and isinstance(track, (Video, Audio)): if not track.drm and isinstance(track, (Video, Audio)):
try: try:
track.drm = [Widevine.from_init_data(init_data)] track.drm = [Widevine.from_init_data(init_data)]

View File

@ -256,11 +256,15 @@ class HLS:
downloader = track.downloader downloader = track.downloader
urls: list[dict[str, Any]] = [] urls: list[dict[str, Any]] = []
segment_durations: list[int] = []
range_offset = 0 range_offset = 0
for segment in master.segments: for segment in master.segments:
if segment in unwanted_segments: if segment in unwanted_segments:
continue continue
segment_durations.append(int(segment.duration))
if segment.byterange: if segment.byterange:
if downloader.__name__ == "aria2c": if downloader.__name__ == "aria2c":
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
@ -277,6 +281,8 @@ class HLS:
} if byte_range else {} } if byte_range else {}
}) })
track.data["hls"]["segment_durations"] = segment_durations
segment_save_dir = save_dir / "segments" segment_save_dir = save_dir / "segments"
for status_update in downloader( for status_update in downloader(

View File

@ -7,7 +7,7 @@ from enum import Enum
from functools import partial from functools import partial
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Iterable, Optional from typing import Any, Callable, Iterable, Optional, Union
import pycaption import pycaption
import requests import requests
@ -20,6 +20,7 @@ from subtitle_filter import Subtitles
from devine.core import binaries from devine.core import binaries
from devine.core.tracks.track import Track from devine.core.tracks.track import Track
from devine.core.utilities import try_ensure_utf8 from devine.core.utilities import try_ensure_utf8
from devine.core.utils.webvtt import merge_segmented_webvtt
class Subtitle(Track): class Subtitle(Track):
@ -202,6 +203,24 @@ class Subtitle(Track):
self.convert(Subtitle.Codec.TimedTextMarkupLang) self.convert(Subtitle.Codec.TimedTextMarkupLang)
elif self.codec == Subtitle.Codec.fVTT: elif self.codec == Subtitle.Codec.fVTT:
self.convert(Subtitle.Codec.WebVTT) self.convert(Subtitle.Codec.WebVTT)
elif self.codec == Subtitle.Codec.WebVTT:
text = self.path.read_text("utf8")
if self.descriptor == Track.Descriptor.DASH:
text = merge_segmented_webvtt(
text,
segment_durations=self.data["dash"]["segment_durations"],
timescale=self.data["dash"]["timescale"]
)
elif self.descriptor == Track.Descriptor.HLS:
text = merge_segmented_webvtt(
text,
segment_durations=self.data["hls"]["segment_durations"],
timescale=1 # ?
)
caption_set = pycaption.WebVTTReader().read(text)
Subtitle.merge_same_cues(caption_set)
subtitle_text = pycaption.WebVTTWriter().write(caption_set)
self.path.write_text(subtitle_text, encoding="utf8")
def convert(self, codec: Subtitle.Codec) -> Path: def convert(self, codec: Subtitle.Codec) -> Path:
""" """
@ -308,14 +327,7 @@ class Subtitle(Track):
caption_lists[language] = caption_list caption_lists[language] = caption_list
caption_set: pycaption.CaptionSet = pycaption.CaptionSet(caption_lists) caption_set: pycaption.CaptionSet = pycaption.CaptionSet(caption_lists)
elif codec == Subtitle.Codec.WebVTT: elif codec == Subtitle.Codec.WebVTT:
text = try_ensure_utf8(data).decode("utf8") text = Subtitle.space_webvtt_headers(data)
# Segmented VTT when merged may have the WEBVTT headers part of the next caption
# if they are not separated far enough from the previous caption, hence the \n\n
text = text. \
replace("WEBVTT", "\n\nWEBVTT"). \
replace("\r", ""). \
replace("\n\n\n", "\n \n\n"). \
replace("\n\n<", "\n<")
caption_set = pycaption.WebVTTReader().read(text) caption_set = pycaption.WebVTTReader().read(text)
else: else:
raise ValueError(f"Unknown Subtitle format \"{codec}\"...") raise ValueError(f"Unknown Subtitle format \"{codec}\"...")
@ -332,6 +344,27 @@ class Subtitle(Track):
return caption_set return caption_set
@staticmethod
def space_webvtt_headers(data: Union[str, bytes]):
"""
Space out the WEBVTT Headers from Captions.
Segmented VTT when merged may have the WEBVTT headers part of the next caption
as they were not separated far enough from the previous caption and ended up
being considered as caption text rather than the header for the next segment.
"""
if isinstance(data, bytes):
data = try_ensure_utf8(data).decode("utf8")
elif not isinstance(data, str):
raise ValueError(f"Expecting data to be a str, not {data!r}")
text = data.replace("WEBVTT", "\n\nWEBVTT").\
replace("\r", "").\
replace("\n\n\n", "\n \n\n").\
replace("\n\n<", "\n<")
return text
@staticmethod @staticmethod
def merge_same_cues(caption_set: pycaption.CaptionSet): def merge_same_cues(caption_set: pycaption.CaptionSet):
"""Merge captions with the same timecodes and text as one in-place.""" """Merge captions with the same timecodes and text as one in-place."""

191
devine/core/utils/webvtt.py Normal file
View File

@ -0,0 +1,191 @@
import re
import sys
import typing
from typing import Optional
from pycaption import Caption, CaptionList, CaptionNode, CaptionReadError, WebVTTReader, WebVTTWriter
class CaptionListExt(CaptionList):
    """A CaptionList that also remembers the first segment's MPEGTS value."""

    @typing.no_type_check
    def __init__(self, iterable=None, layout_info=None):
        # Default baseline of 0; set before base-class init so the attribute
        # exists as soon as the list is constructed.
        self.first_segment_mpegts = 0
        super().__init__(iterable=iterable, layout_info=layout_info)
class CaptionExt(Caption):
    """A Caption extended with segmented-WebVTT merge metadata."""

    @typing.no_type_check
    def __init__(self, start, end, nodes, style=None, layout_info=None, segment_index=0, mpegts=0, cue_time=0.0):
        super().__init__(start, end, nodes, style or {}, layout_info)
        # Index of the (merged) WebVTT segment this cue originated from.
        self.segment_index: int = segment_index
        # MPEGTS value from the segment's X-TIMESTAMP-MAP header, 0 if absent.
        self.mpegts: float = mpegts
        # LOCAL value from the X-TIMESTAMP-MAP header, in seconds.
        self.cue_time: float = cue_time
class WebVTTReaderExt(WebVTTReader):
    """
    WebVTTReader with support for the HLS X-TIMESTAMP-MAP extension header.
    <https://datatracker.ietf.org/doc/html/rfc8216#section-3.5>

    Each parsed cue is tagged with the index of the segment it came from and
    that segment's X-TIMESTAMP-MAP MPEGTS/LOCAL values.
    """
    # HLS extension support <https://datatracker.ietf.org/doc/html/rfc8216#section-3.5>
    RE_TIMESTAMP_MAP = re.compile(r"X-TIMESTAMP-MAP.*")
    RE_MPEGTS = re.compile(r"MPEGTS:(\d+)")
    RE_LOCAL = re.compile(r"LOCAL:((?:(\d{1,}):)?(\d{2}):(\d{2})\.(\d{3}))")

    def _parse(self, lines: list[str]) -> CaptionList:
        """
        Parse (possibly concatenated segmented) WebVTT lines into captions.

        Every "WEBVTT" header line starts a new segment; the segment counter
        and its X-TIMESTAMP-MAP values are attached to each cue built within
        that segment.
        """
        captions = CaptionListExt()
        start = None
        end = None
        nodes: list[CaptionNode] = []
        layout_info = None
        found_timing = False
        segment_index = -1
        mpegts = 0
        cue_time = 0.0

        # The first segment MPEGTS is needed to calculate the rest. It is possible that
        # the first segment contains no cue and is ignored by pycaption, this acts as a fallback.
        captions.first_segment_mpegts = 0

        for i, line in enumerate(lines):
            if "-->" in line:
                found_timing = True
                timing_line = i
                last_start_time = captions[-1].start if captions else 0
                try:
                    start, end, layout_info = self._parse_timing_line(line, last_start_time)
                except CaptionReadError as e:
                    new_msg = f"{e.args[0]} (line {timing_line})"
                    tb = sys.exc_info()[2]
                    raise type(e)(new_msg).with_traceback(tb) from None
            elif "" == line:
                # A blank line terminates the cue currently being collected.
                if found_timing and nodes:
                    found_timing = False
                    caption = CaptionExt(
                        start,
                        end,
                        nodes,
                        layout_info=layout_info,
                        segment_index=segment_index,
                        mpegts=mpegts,
                        cue_time=cue_time,
                    )
                    captions.append(caption)
                    nodes = []
            elif "WEBVTT" in line:
                # Merged segmented VTT doesn't have index information, track manually.
                segment_index += 1
                mpegts = 0
                cue_time = 0.0
            elif m := self.RE_TIMESTAMP_MAP.match(line):
                if r := self.RE_MPEGTS.search(m.group()):
                    mpegts = int(r.group(1))
                cue_time = self._parse_local(m.group())
                # Early assignment in case the first segment contains no cue.
                if segment_index == 0:
                    captions.first_segment_mpegts = mpegts
            else:
                if found_timing:
                    if nodes:
                        nodes.append(CaptionNode.create_break())
                    nodes.append(CaptionNode.create_text(self._decode(line)))
                else:
                    # it's a comment or some metadata; ignore it
                    pass

        # Add a last caption if there are remaining nodes.
        # FIX: also pass cue_time, matching the in-loop construction above, so a
        # trailing cue without a closing blank line keeps its LOCAL offset and is
        # shifted correctly by merge_segmented_webvtt().
        if nodes:
            caption = CaptionExt(
                start,
                end,
                nodes,
                layout_info=layout_info,
                segment_index=segment_index,
                mpegts=mpegts,
                cue_time=cue_time,
            )
            captions.append(caption)

        return captions

    @staticmethod
    def _parse_local(string: str) -> float:
        """
        Parse WebVTT LOCAL time and convert it to seconds.

        Returns 0 if no LOCAL timestamp is found in the string.
        """
        m = WebVTTReaderExt.RE_LOCAL.search(string)
        if not m:
            return 0
        parsed = m.groups()
        if not parsed:
            return 0
        # FIX: the hours component is optional in WebVTT timestamps (the
        # "MM:SS.mmm" form), in which case its capture group is None — treat it
        # as 0 instead of crashing on int(None).
        hours = int(parsed[1] or 0)
        minutes = int(parsed[2])
        seconds = int(parsed[3])
        milliseconds = int(parsed[4])
        return (milliseconds / 1000) + seconds + (minutes * 60) + (hours * 3600)
def merge_segmented_webvtt(vtt_raw: str, segment_durations: Optional[list[int]] = None, timescale: int = 1) -> str:
    """
    Merge Segmented WebVTT data.

    Parameters:
        vtt_raw: The concatenated WebVTT files to merge. All WebVTT headers must be
            appropriately spaced apart, or it may produce unwanted effects like
            considering headers as captions, timestamp lines, etc.
        segment_durations: A list of each segment's duration. If not provided it will try
            to get it from the X-TIMESTAMP-MAP headers, specifically the MPEGTS number.
        timescale: The number of time units per second.

    This parses the X-TIMESTAMP-MAP data to compute new absolute timestamps, replacing
    the old start and end timestamp values. All X-TIMESTAMP-MAP header information will
    be removed from the output as they are no longer of concern. Consider this function
    the opposite of a WebVTT Segmenter, a WebVTT Joiner of sorts.

    Algorithm borrowed from N_m3u8DL-RE and shaka-player.
    """
    # MPEG-TS clock rate used by the X-TIMESTAMP-MAP MPEGTS values.
    MPEG_TIMESCALE = 90_000

    vtt = WebVTTReaderExt().read(vtt_raw)
    for lang in vtt.get_languages():
        prev_caption = None
        duplicate_index: list[int] = []
        # NOTE(review): assumes get_captions(lang) is non-empty — captions[0]
        # below would raise IndexError otherwise; confirm pycaption guarantees
        # that every language returned by get_languages() has captions.
        captions = vtt.get_captions(lang)
        if captions[0].segment_index == 0:
            first_segment_mpegts = captions[0].mpegts
        else:
            # The first cue came from a later segment (segment 0 had no cues);
            # fall back to the first segment's duration, or to the value the
            # reader recorded while parsing.
            first_segment_mpegts = segment_durations[0] if segment_durations else captions.first_segment_mpegts
        caption: CaptionExt
        for i, caption in enumerate(captions):
            # DASH WebVTT doesn't have MPEGTS timestamp like HLS. Instead,
            # calculate the timestamp from SegmentTemplate/SegmentList duration.
            likely_dash = first_segment_mpegts == 0 and caption.mpegts == 0
            if likely_dash and segment_durations:
                duration = segment_durations[caption.segment_index]
                caption.mpegts = MPEG_TIMESCALE * (duration / timescale)
            if caption.mpegts == 0:
                continue
            # Absolute shift for this segment: MPEGTS delta minus the segment's
            # LOCAL cue time (both relative to the first segment's baseline).
            seconds = (caption.mpegts - first_segment_mpegts) / MPEG_TIMESCALE - caption.cue_time
            offset = seconds * 1_000_000  # pycaption use microseconds
            # Only shift cues still carrying segment-local times; cues already at
            # or past the offset are treated as absolute and left untouched.
            if caption.start < offset:
                caption.start += offset
                caption.end += offset
            # If the difference between current and previous captions is <=1ms
            # and the payload is equal then splice.
            if (
                prev_caption
                and not caption.is_empty()
                and (caption.start - prev_caption.end) <= 1000  # 1ms in microseconds
                and caption.get_text() == prev_caption.get_text()
            ):
                prev_caption.end = caption.end
                duplicate_index.append(i)
            prev_caption = caption
        # Remove duplicate
        captions[:] = [c for c_index, c in enumerate(captions) if c_index not in set(duplicate_index)]
    return WebVTTWriter().write(vtt)