Move download_segment() from DASH/HLS download_track() to Class

Various overall small readability improvements have also been made.
This commit is contained in:
rlaphoenix 2023-05-17 03:12:29 +01:00
parent 03c012f88e
commit dd64212ad2
2 changed files with 288 additions and 216 deletions

View File

@ -13,7 +13,7 @@ from functools import partial
from hashlib import md5
from pathlib import Path
from threading import Event
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Optional, Union, MutableMapping
from urllib.parse import urljoin, urlparse
from uuid import UUID
@ -392,7 +392,8 @@ class DASH:
# last chance to find the KID, assumes first segment will hold the init data
track_kid = track_kid or track.get_key_id(url=segments[0][0], session=session)
# license and grab content keys
drm = track.drm[0] # just use the first supported DRM system for now
# TODO: What if we don't want to use the first DRM system?
drm = track.drm[0]
if isinstance(drm, Widevine):
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
@ -404,74 +405,26 @@ class DASH:
progress(downloaded="[yellow]SKIPPED")
return
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if stop_event.is_set():
# the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
segment_save_path = (save_dir / filename).with_suffix(".mp4")
segment_uri, segment_range = segment
attempts = 1
while True:
try:
downloader_ = downloader
headers_ = session.headers
if segment_range:
# aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
downloader_ = requests_downloader
headers_["Range"] = f"bytes={segment_range}"
downloader_(
uri=segment_uri,
out=segment_save_path,
headers=headers_,
proxy=proxy,
silent=attempts != 5,
segmented=True
)
break
except Exception as ee:
if stop_event.is_set() or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
data_size = segment_save_path.stat().st_size
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Should this be done in the video data or the init data?
if isinstance(track, Audio):
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
fixed_segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
if fixed_segment_data != segment_data:
f.seek(0)
f.write(fixed_segment_data)
return data_size
progress(total=len(segments))
finished_threads = 0
download_sizes = []
download_speed_window = 5
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool:
for download in futures.as_completed((
for i, download in enumerate(futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
DASH.download_segment,
url=url,
out_path=(save_dir / str(n).zfill(len(str(len(segments))))).with_suffix(".mp4"),
track=track,
proxy=proxy,
headers=session.headers,
bytes_range=bytes_range,
stop_event=stop_event
)
for i, segment in enumerate(segments)
)):
finished_threads += 1
for n, (url, bytes_range) in enumerate(segments)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
@ -482,16 +435,15 @@ class DASH:
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e:
except Exception:
stop_event.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise e
raise
else:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
now = time.time()
@ -500,7 +452,7 @@ class DASH:
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_sizes and (time_since > 5 or finished_threads == len(segments)):
if download_sizes and (time_since > download_speed_window or i == len(segments)):
data_size = sum(download_sizes)
download_speed = data_size / (time_since or 1)
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
@ -527,6 +479,76 @@ class DASH:
progress(downloaded="Downloaded")
@staticmethod
def download_segment(
url: str,
out_path: Path,
track: AnyTrack,
proxy: Optional[str] = None,
headers: Optional[MutableMapping[str, str | bytes]] = None,
bytes_range: Optional[str] = None,
stop_event: Optional[Event] = None
) -> int:
"""
Download a DASH Media Segment.
Parameters:
url: Full HTTP(S) URL to the Segment you want to download.
out_path: Path to save the downloaded Segment file to.
track: The Track object of which this Segment is for. Currently only used to
fix an invalid value in the TFHD box of Audio Tracks.
proxy: Proxy URI to use when downloading the Segment file.
headers: HTTP Headers to send when requesting the Segment file.
bytes_range: Download only specific bytes of the Segment file using the Range header.
stop_event: Prematurely stop the Download from beginning. Useful if ran from
a Thread Pool. It will raise a KeyboardInterrupt if set.
Returns the file size of the downloaded Segment in bytes.
"""
if stop_event and stop_event.is_set():
raise KeyboardInterrupt()
attempts = 1
while True:
try:
headers_ = headers or {}
if bytes_range:
# aria2(c) doesn't support byte ranges, use python-requests
downloader_ = requests_downloader
headers_["Range"] = f"bytes={bytes_range}"
else:
downloader_ = downloader
downloader_(
uri=url,
out=out_path,
headers=headers_,
proxy=proxy,
silent=attempts != 5,
segmented=True
)
break
except Exception as ee:
if (stop_event and stop_event.is_set()) or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Should this be done in the video data or the init data?
if isinstance(track, Audio):
with open(out_path, "rb+") as f:
segment_data = f.read()
fixed_segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
if fixed_segment_data != segment_data:
f.seek(0)
f.write(fixed_segment_data)
return out_path.stat().st_size
@staticmethod
def _get(
item: str,

View File

@ -214,137 +214,6 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1)
drm_lock = Lock()
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if stop_event.is_set():
# the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
newest_init_data = init_data.get()
try:
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
# be unnecessary and slow to re-download the init data each time.
if not segment.init_section.uri.startswith(segment.init_section.base_uri):
segment.init_section.uri = segment.init_section.base_uri + segment.init_section.uri
if segment.init_section.byterange:
byte_range = HLS.calculate_byte_range(segment.init_section.byterange)
_ = range_offset.get()
range_offset.put(byte_range.split("-")[0])
headers = {
"Range": f"bytes={byte_range}"
}
else:
headers = {}
log.debug("Got new init segment, %s", segment.init_section.uri)
res = session.get(segment.init_section.uri, headers=headers)
res.raise_for_status()
newest_init_data = res.content
finally:
init_data.put(newest_init_data)
with drm_lock:
newest_segment_key = segment_key.get()
try:
if segment.keys and newest_segment_key[1] != segment.keys:
try:
drm = HLS.get_drm(
keys=segment.keys,
proxy=proxy
)
except NotImplementedError as e:
log.error(str(e))
sys.exit(1)
else:
if drm:
track.drm = drm
drm = drm[0] # just use the first supported DRM system for now
log.debug("Got segment key, %s", drm)
if isinstance(drm, Widevine):
# license and grab content keys
track_kid = track.get_key_id(newest_init_data)
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm, track_kid=track_kid)
newest_segment_key = (drm, segment.keys)
finally:
segment_key.put(newest_segment_key)
if skip_event.is_set():
progress(downloaded="[yellow]SKIPPING")
return 0
if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri
attempts = 1
while True:
try:
downloader_ = downloader
headers_ = session.headers
if segment.byterange:
# aria2(c) doesn't support byte ranges, let's use python-requests (likely slower)
previous_range_offset = range_offset.get()
byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
range_offset.put(byte_range.split("-")[0])
downloader_ = requests_downloader
headers_["Range"] = f"bytes={byte_range}"
downloader_(
uri=segment.uri,
out=segment_save_path,
headers=headers_,
proxy=proxy,
silent=attempts != 5,
segmented=True
)
break
except Exception as ee:
if stop_event.is_set() or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
data_size = segment_save_path.stat().st_size
if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
if isinstance(track, Audio):
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Is this in mpeg data, or init data?
segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
# prepend the init data to be able to decrypt
if newest_init_data:
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
if newest_segment_key[0]:
newest_segment_key[0].decrypt(segment_save_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
return data_size
segment_key = Queue(maxsize=1)
init_data = Queue(maxsize=1)
range_offset = Queue(maxsize=1)
if track.drm:
session_drm = track.drm[0] # just use the first supported DRM system for now
if isinstance(session_drm, Widevine):
@ -355,30 +224,39 @@ class HLS:
else:
session_drm = None
# have data to begin with, or it will be stuck waiting on the first pool forever
segment_key.put((session_drm, None))
init_data.put(None)
range_offset.put(0)
progress(total=len(master.segments))
finished_threads = 0
download_sizes = []
download_speed_window = 5
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
finished_threads += 1
segment_key = Queue(maxsize=1)
segment_key.put((session_drm, None))
init_data = Queue(maxsize=1)
init_data.put(None)
range_offset = Queue(maxsize=1)
range_offset.put(0)
drm_lock = Lock()
with ThreadPoolExecutor(max_workers=16) as pool:
for i, download in enumerate(futures.as_completed((
pool.submit(
HLS.download_segment,
segment=segment,
out_path=(save_dir / str(n).zfill(len(str(len(master.segments))))).with_suffix(".mp4"),
track=track,
init_data=init_data,
segment_key=segment_key,
range_offset=range_offset,
drm_lock=drm_lock,
license_widevine=license_widevine,
session=session,
proxy=proxy,
stop_event=stop_event,
skip_event=skip_event
)
for n, segment in enumerate(master.segments)
))):
try:
download_size = download.result()
except KeyboardInterrupt:
@ -401,13 +279,17 @@ class HLS:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
if download_size == -1: # skipped for --skip-dl
progress(downloaded="[yellow]SKIPPING")
continue
now = time.time()
time_since = now - last_speed_refresh
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_sizes and (time_since > 5 or finished_threads == len(master.segments)):
if download_sizes and (time_since > download_speed_window or i == len(master.segments)):
data_size = sum(download_sizes)
download_speed = data_size / (time_since or 1)
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
@ -424,6 +306,174 @@ class HLS:
track.path = save_path
save_dir.rmdir()
@staticmethod
def download_segment(
segment: m3u8.Segment,
out_path: Path,
track: AnyTrack,
init_data: Queue,
segment_key: Queue,
range_offset: Queue,
drm_lock: Lock,
license_widevine: Optional[Callable] = None,
session: Optional[Session] = None,
proxy: Optional[str] = None,
stop_event: Optional[Event] = None,
skip_event: Optional[Event] = None
) -> int:
"""
Download (and Decrypt) an HLS Media Segment.
Note: Make sure all Queue objects passed are appropriately initialized with
a starting value or this function may get permanently stuck.
Parameters:
segment: The m3u8.Segment Object to Download.
out_path: Path to save the downloaded Segment file to.
track: The Track object of which this Segment is for. Currently used to fix an
invalid value in the TFHD box of Audio Tracks, for the OnSegmentFilter, and
for DRM-related operations like getting the Track ID and Decryption.
init_data: Queue for saving and loading the most recent init section data.
segment_key: Queue for saving and loading the most recent DRM object, and it's
adjacent Segment.Key object.
range_offset: Queue for saving and loading the most recent Segment Bytes Range.
drm_lock: Prevent more than one Download from doing anything DRM-related at the
same time. Make sure all calls to download_segment() use the same Lock object.
license_widevine: Function used to license Widevine DRM objects. It must be passed
if the Segment's DRM uses Widevine.
proxy: Proxy URI to use when downloading the Segment file.
session: Python-Requests Session used when requesting init data.
stop_event: Prematurely stop the Download from beginning. Useful if ran from
a Thread Pool. It will raise a KeyboardInterrupt if set.
skip_event: Prematurely stop the Download from beginning. It returns with a
file size of -1 directly after DRM licensing occurs, even if it's DRM-free.
This is mainly for `--skip-dl` to allow licensing without downloading.
Returns the file size of the downloaded Segment in bytes.
"""
if stop_event.is_set():
raise KeyboardInterrupt()
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return 0
# handle init section changes
newest_init_data = init_data.get()
try:
if segment.init_section and (not newest_init_data or segment.discontinuity):
# Only use the init data if there's no init data yet (e.g., start of file)
# or if EXT-X-DISCONTINUITY is reached at the same time as EXT-X-MAP.
# Even if a new EXT-X-MAP is supplied, it may just be duplicate and would
# be unnecessary and slow to re-download the init data each time.
if not segment.init_section.uri.startswith(segment.init_section.base_uri):
segment.init_section.uri = segment.init_section.base_uri + segment.init_section.uri
if segment.init_section.byterange:
byte_range = HLS.calculate_byte_range(segment.init_section.byterange)
_ = range_offset.get()
range_offset.put(byte_range.split("-")[0])
range_header = {
"Range": f"bytes={byte_range}"
}
else:
range_header = {}
res = session.get(segment.init_section.uri, headers=range_header)
res.raise_for_status()
newest_init_data = res.content
finally:
init_data.put(newest_init_data)
# handle segment key changes
with drm_lock:
newest_segment_key = segment_key.get()
try:
if segment.keys and newest_segment_key[1] != segment.keys:
drm = HLS.get_drm(
keys=segment.keys,
proxy=proxy
)
if drm:
track.drm = drm
# license and grab content keys
# TODO: What if we don't want to use the first DRM system?
drm = drm[0]
if isinstance(drm, Widevine):
track_kid = track.get_key_id(newest_init_data)
if not license_widevine:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm, track_kid=track_kid)
newest_segment_key = (drm, segment.keys)
finally:
segment_key.put(newest_segment_key)
if skip_event.is_set():
return -1
if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri
attempts = 1
while True:
try:
headers_ = session.headers
if segment.byterange:
# aria2(c) doesn't support byte ranges, use python-requests
downloader_ = requests_downloader
previous_range_offset = range_offset.get()
byte_range = HLS.calculate_byte_range(segment.byterange, previous_range_offset)
range_offset.put(byte_range.split("-")[0])
headers_["Range"] = f"bytes={byte_range}"
else:
downloader_ = downloader
downloader_(
uri=segment.uri,
out=out_path,
headers=headers_,
proxy=proxy,
silent=attempts != 5,
segmented=True
)
break
except Exception as ee:
if stop_event.is_set() or attempts == 5:
raise ee
time.sleep(2)
attempts += 1
download_size = out_path.stat().st_size
# fix audio decryption on ATVP by fixing the sample description index
# TODO: Should this be done in the video data or the init data?
if isinstance(track, Audio):
with open(out_path, "rb+") as f:
segment_data = f.read()
fixed_segment_data = re.sub(
b"(tfhd\x00\x02\x00\x1a\x00\x00\x00\x01\x00\x00\x00)\x02",
b"\\g<1>\x01",
segment_data
)
if fixed_segment_data != segment_data:
f.seek(0)
f.write(fixed_segment_data)
# prepend the init data to be able to decrypt
if newest_init_data:
with open(out_path, "rb+") as f:
segment_data = f.read()
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
# decrypt segment if encrypted
if newest_segment_key[0]:
newest_segment_key[0].decrypt(out_path)
track.drm = None
if callable(track.OnDecrypted):
track.OnDecrypted(track)
return download_size
@staticmethod
def get_drm(
keys: list[Union[m3u8.model.SessionKey, m3u8.model.Key]],