Add support for SegmentBase and BaseURL-only DASH Manifests

This commit is contained in:
rlaphoenix 2024-02-05 10:22:40 +00:00
parent c06ea4cea8
commit 3b62b50e25
2 changed files with 225 additions and 202 deletions

View File

@ -856,8 +856,7 @@ class dl:
proxy=proxy, proxy=proxy,
license_widevine=prepare_drm license_widevine=prepare_drm
) )
# no else-if as DASH may convert the track to URL descriptor elif track.descriptor == track.Descriptor.URL:
if track.descriptor == track.Descriptor.URL:
try: try:
if not track.drm and isinstance(track, (Video, Audio)): if not track.drm and isinstance(track, (Video, Audio)):
# the service might not have explicitly defined the `drm` property # the service might not have explicitly defined the `drm` property

View File

@ -268,228 +268,252 @@ class DASH:
if segment_list is None: if segment_list is None:
segment_list = adaptation_set.find("SegmentList") segment_list = adaptation_set.find("SegmentList")
if segment_template is None and segment_list is None and rep_base_url: segment_base = representation.find("SegmentBase")
# If there's no SegmentTemplate and no SegmentList, then SegmentBase is used or just BaseURL if segment_base is None:
# Regardless which of the two is used, we can just directly grab the BaseURL segment_base = adaptation_set.find("SegmentBase")
# Players would normally calculate segments via Byte-Ranges, but we don't care
track.url = rep_base_url
track.descriptor = track.Descriptor.URL
else:
segments: list[tuple[str, Optional[str]]] = []
track_kid: Optional[UUID] = None
if segment_template is not None: segments: list[tuple[str, Optional[str]]] = []
segment_template = copy(segment_template) track_kid: Optional[UUID] = None
start_number = int(segment_template.get("startNumber") or 1)
segment_timeline = segment_template.find("SegmentTimeline")
for item in ("initialization", "media"): if segment_template is not None:
value = segment_template.get(item) segment_template = copy(segment_template)
if not value: start_number = int(segment_template.get("startNumber") or 1)
continue segment_timeline = segment_template.find("SegmentTimeline")
if not re.match("^https?://", value, re.IGNORECASE):
if not rep_base_url:
raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
value = urljoin(rep_base_url, value)
if not urlparse(value).query and manifest_url_query:
value += f"?{manifest_url_query}"
segment_template.set(item, value)
init_url = segment_template.get("initialization") for item in ("initialization", "media"):
if init_url: value = segment_template.get(item)
res = session.get(DASH.replace_fields( if not value:
init_url, continue
Bandwidth=representation.get("bandwidth"), if not re.match("^https?://", value, re.IGNORECASE):
RepresentationID=representation.get("id") if not rep_base_url:
)) raise ValueError("Resolved Segment URL is not absolute, and no Base URL is available.")
res.raise_for_status() value = urljoin(rep_base_url, value)
init_data = res.content if not urlparse(value).query and manifest_url_query:
track_kid = track.get_key_id(init_data) value += f"?{manifest_url_query}"
segment_template.set(item, value)
if segment_timeline is not None: init_url = segment_template.get("initialization")
seg_time_list = [] if init_url:
current_time = 0 res = session.get(DASH.replace_fields(
for s in segment_timeline.findall("S"): init_url,
if s.get("t"): Bandwidth=representation.get("bandwidth"),
current_time = int(s.get("t")) RepresentationID=representation.get("id")
for _ in range(1 + (int(s.get("r") or 0))): ))
seg_time_list.append(current_time) res.raise_for_status()
current_time += int(s.get("d")) init_data = res.content
seg_num_list = list(range(start_number, len(seg_time_list) + start_number)) track_kid = track.get_key_id(init_data)
for t, n in zip(seg_time_list, seg_num_list): if segment_timeline is not None:
segments.append(( seg_time_list = []
DASH.replace_fields( current_time = 0
segment_template.get("media"), for s in segment_timeline.findall("S"):
Bandwidth=representation.get("bandwidth"), if s.get("t"):
Number=n, current_time = int(s.get("t"))
RepresentationID=representation.get("id"), for _ in range(1 + (int(s.get("r") or 0))):
Time=t seg_time_list.append(current_time)
), None current_time += int(s.get("d"))
)) seg_num_list = list(range(start_number, len(seg_time_list) + start_number))
else:
if not period_duration:
raise ValueError("Duration of the Period was unable to be determined.")
period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
for s in range(start_number, start_number + total_segments):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s
), None
))
elif segment_list is not None:
init_data = None
initialization = segment_list.find("Initialization")
if initialization is not None:
source_url = initialization.get("sourceURL")
if not source_url or not re.match("^https?://", source_url, re.IGNORECASE):
source_url = urljoin(rep_base_url, f"./{source_url}")
if initialization.get("range"):
init_range_header = {"Range": f"bytes={initialization.get('range')}"}
else:
init_range_header = None
res = session.get(url=source_url, headers=init_range_header)
res.raise_for_status()
init_data = res.content
track_kid = track.get_key_id(init_data)
segment_urls = segment_list.findall("SegmentURL")
for segment_url in segment_urls:
media_url = segment_url.get("media")
if not media_url or not re.match("^https?://", media_url, re.IGNORECASE):
media_url = urljoin(rep_base_url, f"./{media_url}")
for t, n in zip(seg_time_list, seg_num_list):
segments.append(( segments.append((
media_url, DASH.replace_fields(
segment_url.get("mediaRange") segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=n,
RepresentationID=representation.get("id"),
Time=t
), None
)) ))
else: else:
log.error("Could not find a way to get segments from this MPD manifest.") if not period_duration:
log.debug(manifest_url) raise ValueError("Duration of the Period was unable to be determined.")
sys.exit(1) period_duration = DASH.pt_to_sec(period_duration)
segment_duration = float(segment_template.get("duration"))
segment_timescale = float(segment_template.get("timescale") or 1)
total_segments = math.ceil(period_duration / (segment_duration / segment_timescale))
if not track.drm and isinstance(track, (Video, Audio)): for s in range(start_number, start_number + total_segments):
segments.append((
DASH.replace_fields(
segment_template.get("media"),
Bandwidth=representation.get("bandwidth"),
Number=s,
RepresentationID=representation.get("id"),
Time=s
), None
))
elif segment_list is not None:
init_data = None
initialization = segment_list.find("Initialization")
if initialization is not None:
source_url = initialization.get("sourceURL")
if not source_url or not re.match("^https?://", source_url, re.IGNORECASE):
source_url = urljoin(rep_base_url, f"./{source_url}")
if initialization.get("range"):
init_range_header = {"Range": f"bytes={initialization.get('range')}"}
else:
init_range_header = None
res = session.get(url=source_url, headers=init_range_header)
res.raise_for_status()
init_data = res.content
track_kid = track.get_key_id(init_data)
segment_urls = segment_list.findall("SegmentURL")
for segment_url in segment_urls:
media_url = segment_url.get("media")
if not media_url or not re.match("^https?://", media_url, re.IGNORECASE):
media_url = urljoin(rep_base_url, f"./{media_url}")
segments.append((
media_url,
segment_url.get("mediaRange")
))
elif segment_base is not None:
media_range = None
init_data = None
initialization = segment_base.find("Initialization")
if initialization is not None:
if initialization.get("range"):
init_range_header = {"Range": f"bytes={initialization.get('range')}"}
else:
init_range_header = None
res = session.get(url=rep_base_url, headers=init_range_header)
res.raise_for_status()
init_data = res.content
track_kid = track.get_key_id(init_data)
total_size = res.headers.get("Content-Range", "").split("/")[-1]
if total_size:
media_range = f"{len(init_data)}-{total_size}"
segments.append((
rep_base_url,
media_range
))
elif rep_base_url:
segments.append((
rep_base_url,
None
))
else:
log.error("Could not find a way to get segments from this MPD manifest.")
log.debug(manifest_url)
sys.exit(1)
if not track.drm and isinstance(track, (Video, Audio)):
try:
track.drm = [Widevine.from_init_data(init_data)]
except Widevine.Exceptions.PSSHNotFound:
# it might not have Widevine DRM, or might not have found the PSSH
log.warning("No Widevine PSSH was found for this track, is it DRM free?")
if track.drm:
# last chance to find the KID, assumes first segment will hold the init data
track_kid = track_kid or track.get_key_id(url=segments[0][0], session=session)
# TODO: What if we don't want to use the first DRM system?
drm = track.drm[0]
if isinstance(drm, Widevine):
# license and grab content keys
try: try:
track.drm = [Widevine.from_init_data(init_data)] if not license_widevine:
except Widevine.Exceptions.PSSHNotFound: raise ValueError("license_widevine func must be supplied to use Widevine DRM")
# it might not have Widevine DRM, or might not have found the PSSH progress(downloaded="LICENSING")
log.warning("No Widevine PSSH was found for this track, is it DRM free?") license_widevine(drm, track_kid=track_kid)
progress(downloaded="[yellow]LICENSED")
except Exception: # noqa
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[red]FAILED")
raise
else:
drm = None
if track.drm: if DOWNLOAD_LICENCE_ONLY.is_set():
# last chance to find the KID, assumes first segment will hold the init data progress(downloaded="[yellow]SKIPPED")
track_kid = track_kid or track.get_key_id(url=segments[0][0], session=session) return
# TODO: What if we don't want to use the first DRM system?
drm = track.drm[0]
if isinstance(drm, Widevine):
# license and grab content keys
try:
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
progress(downloaded="LICENSING")
license_widevine(drm, track_kid=track_kid)
progress(downloaded="[yellow]LICENSED")
except Exception: # noqa
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[red]FAILED")
raise
else:
drm = None
if DOWNLOAD_LICENCE_ONLY.is_set(): progress(total=len(segments))
progress(downloaded="[yellow]SKIPPED")
return
progress(total=len(segments)) download_sizes = []
download_speed_window = 5
last_speed_refresh = time.time()
download_sizes = [] with ThreadPoolExecutor(max_workers=16) as pool:
download_speed_window = 5 for i, download in enumerate(futures.as_completed((
last_speed_refresh = time.time() pool.submit(
DASH.download_segment,
url=url,
out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"),
track=track,
proxy=proxy,
headers=session.headers,
cookies=session.cookies,
bytes_range=bytes_range
)
for n, (url, bytes_range) in enumerate(segments)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[yellow]CANCELLING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]CANCELLED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise
else:
progress(advance=1)
with ThreadPoolExecutor(max_workers=16) as pool: now = time.time()
for i, download in enumerate(futures.as_completed(( time_since = now - last_speed_refresh
pool.submit(
DASH.download_segment,
url=url,
out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"),
track=track,
proxy=proxy,
headers=session.headers,
cookies=session.cookies,
bytes_range=bytes_range
)
for n, (url, bytes_range) in enumerate(segments)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[yellow]CANCELLING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]CANCELLED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception:
DOWNLOAD_CANCELLED.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise
else:
progress(advance=1)
now = time.time() if download_size: # no size == skipped dl
time_since = now - last_speed_refresh download_sizes.append(download_size)
if download_size: # no size == skipped dl if download_sizes and (time_since > download_speed_window or i == len(segments)):
download_sizes.append(download_size) data_size = sum(download_sizes)
download_speed = data_size / (time_since or 1)
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if download_sizes and (time_since > download_speed_window or i == len(segments)): with open(save_path, "wb") as f:
data_size = sum(download_sizes) if init_data:
download_speed = data_size / (time_since or 1) f.write(init_data)
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s") for segment_file in sorted(save_dir.iterdir()):
last_speed_refresh = now segment_data = segment_file.read_bytes()
download_sizes.clear() # TODO: fix encoding after decryption?
if (
not drm and isinstance(track, Subtitle) and
track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML)
):
segment_data = try_ensure_utf8(segment_data)
segment_data = html.unescape(segment_data.decode("utf8")).encode("utf8")
f.write(segment_data)
segment_file.unlink()
with open(save_path, "wb") as f: if drm:
if init_data: progress(downloaded="Decrypting", completed=0, total=100)
f.write(init_data) drm.decrypt(save_path)
for segment_file in sorted(save_dir.iterdir()): track.drm = None
segment_data = segment_file.read_bytes() if callable(track.OnDecrypted):
# TODO: fix encoding after decryption? track.OnDecrypted(track)
if ( progress(downloaded="Decrypted", completed=100)
not drm and isinstance(track, Subtitle) and
track.codec not in (Subtitle.Codec.fVTT, Subtitle.Codec.fTTML)
):
segment_data = try_ensure_utf8(segment_data)
segment_data = html.unescape(segment_data.decode("utf8")).encode("utf8")
f.write(segment_data)
segment_file.unlink()
if drm: track.path = save_path
progress(downloaded="Decrypting", completed=0, total=100) save_dir.rmdir()
drm.decrypt(save_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
progress(downloaded="Decrypted", completed=100)
track.path = save_path progress(downloaded="Downloaded")
save_dir.rmdir()
progress(downloaded="Downloaded")
@staticmethod @staticmethod
def download_segment( def download_segment(