Alter behaviour of --skip-dl to allow DRM licensing

Most people used --skip-dl just to license the DRM pre-v1.3.0. Which makes sense, --skip-dl is otherwise a pointless feature. I've fixed it so that --skip-dl worked like before, allowing license calls, while still supporting the new per-segment features post-v1.3.0.

Fixes #37
This commit is contained in:
rlaphoenix 2023-05-11 22:17:41 +01:00
parent 3ec317e9d6
commit b92708ef45
3 changed files with 125 additions and 100 deletions

View File

@ -134,6 +134,7 @@ class dl:
return dl(ctx, **kwargs) return dl(ctx, **kwargs)
DL_POOL_STOP = Event() DL_POOL_STOP = Event()
DL_POOL_SKIP = Event()
DRM_TABLE_LOCK = Lock() DRM_TABLE_LOCK = Lock()
def __init__( def __init__(
@ -458,75 +459,78 @@ class dl:
download_table = Table.grid() download_table = Table.grid()
download_table.add_row(selected_tracks) download_table.add_row(selected_tracks)
dl_start_time = time.time()
if skip_dl: if skip_dl:
self.log.info("Skipping Download...") self.DL_POOL_SKIP.set()
else:
dl_start_time = time.time()
try: try:
with Live( with Live(
Padding( Padding(
download_table, download_table,
(1, 5)
),
console=console,
refresh_per_second=5
):
with ThreadPoolExecutor(workers) as pool:
for download in futures.as_completed((
pool.submit(
self.download_track,
service=service,
track=track,
prepare_drm=partial(
partial(
self.prepare_drm,
table=download_table
),
track=track,
title=title,
certificate=partial(
service.get_widevine_service_certificate,
title=title,
track=track
),
licence=partial(
service.get_widevine_license,
title=title,
track=track
),
cdm_only=cdm_only,
vaults_only=vaults_only,
export=export
),
progress=tracks_progress_callables[i]
)
for i, track in enumerate(title.tracks)
)):
download.result()
except KeyboardInterrupt:
console.print(Padding(
":x: Download Cancelled...",
(0, 5, 1, 5)
))
return
except Exception as e: # noqa
error_messages = [
":x: Download Failed...",
" One of the download workers had an error!",
" See the error trace above for more information."
]
if isinstance(e, subprocess.CalledProcessError):
# ignore process exceptions as proper error logs are already shown
error_messages.append(f" Process exit code: {e.returncode}")
else:
console.print_exception()
console.print(Padding(
Group(*error_messages),
(1, 5) (1, 5)
)) ),
return console=console,
refresh_per_second=5
):
with ThreadPoolExecutor(workers) as pool:
for download in futures.as_completed((
pool.submit(
self.download_track,
service=service,
track=track,
prepare_drm=partial(
partial(
self.prepare_drm,
table=download_table
),
track=track,
title=title,
certificate=partial(
service.get_widevine_service_certificate,
title=title,
track=track
),
licence=partial(
service.get_widevine_license,
title=title,
track=track
),
cdm_only=cdm_only,
vaults_only=vaults_only,
export=export
),
progress=tracks_progress_callables[i]
)
for i, track in enumerate(title.tracks)
)):
download.result()
except KeyboardInterrupt:
console.print(Padding(
":x: Download Cancelled...",
(0, 5, 1, 5)
))
return
except Exception as e: # noqa
error_messages = [
":x: Download Failed...",
" One of the download workers had an error!",
" See the error trace above for more information."
]
if isinstance(e, subprocess.CalledProcessError):
# ignore process exceptions as proper error logs are already shown
error_messages.append(f" Process exit code: {e.returncode}")
else:
console.print_exception()
console.print(Padding(
Group(*error_messages),
(1, 5)
))
return
if skip_dl:
console.log("Skipped downloads as --skip-dl was used...")
else:
dl_time = time_elapsed_since(dl_start_time) dl_time = time_elapsed_since(dl_start_time)
console.print(Padding( console.print(Padding(
f"Track downloads finished in [progress.elapsed]{dl_time}[/]", f"Track downloads finished in [progress.elapsed]{dl_time}[/]",
@ -806,6 +810,9 @@ class dl:
prepare_drm: Callable, prepare_drm: Callable,
progress: partial progress: partial
): ):
if self.DL_POOL_SKIP.is_set():
progress(downloaded="[yellow]SKIPPING")
if self.DL_POOL_STOP.is_set(): if self.DL_POOL_STOP.is_set():
progress(downloaded="[yellow]SKIPPED") progress(downloaded="[yellow]SKIPPED")
return return
@ -815,12 +822,6 @@ class dl:
else: else:
proxy = None proxy = None
if config.directories.temp.is_file():
self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file")
sys.exit(1)
config.directories.temp.mkdir(parents=True, exist_ok=True)
save_path = config.directories.temp / f"{track.__class__.__name__}_{track.id}.mp4" save_path = config.directories.temp / f"{track.__class__.__name__}_{track.id}.mp4"
if isinstance(track, Subtitle): if isinstance(track, Subtitle):
save_path = save_path.with_suffix(f".{track.codec.extension}") save_path = save_path.with_suffix(f".{track.codec.extension}")
@ -841,11 +842,18 @@ class dl:
if save_dir.exists() and save_dir.name.endswith("_segments"): if save_dir.exists() and save_dir.name.endswith("_segments"):
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
# Delete any pre-existing temp files matching this track. if not self.DL_POOL_SKIP.is_set():
# We can't re-use or continue downloading these tracks as they do not use a if config.directories.temp.is_file():
# lock file. Or at least the majority don't. Even if they did I've encountered self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file")
# corruptions caused by sudden interruptions to the lock file. sys.exit(1)
cleanup()
config.directories.temp.mkdir(parents=True, exist_ok=True)
# Delete any pre-existing temp files matching this track.
# We can't re-use or continue downloading these tracks as they do not use a
# lock file. Or at least the majority don't. Even if they did I've encountered
# corruptions caused by sudden interruptions to the lock file.
cleanup()
try: try:
if track.descriptor == track.Descriptor.M3U: if track.descriptor == track.Descriptor.M3U:
@ -854,6 +862,7 @@ class dl:
save_path=save_path, save_path=save_path,
save_dir=save_dir, save_dir=save_dir,
stop_event=self.DL_POOL_STOP, stop_event=self.DL_POOL_STOP,
skip_event=self.DL_POOL_SKIP,
progress=progress, progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,
@ -865,6 +874,7 @@ class dl:
save_path=save_path, save_path=save_path,
save_dir=save_dir, save_dir=save_dir,
stop_event=self.DL_POOL_STOP, stop_event=self.DL_POOL_STOP,
skip_event=self.DL_POOL_SKIP,
progress=progress, progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,
@ -893,21 +903,24 @@ class dl:
else: else:
drm = None drm = None
asyncio.run(aria2c( if self.DL_POOL_SKIP.is_set():
uri=track.url, progress(downloaded="[yellow]SKIPPED")
out=save_path, else:
headers=service.session.headers, asyncio.run(aria2c(
proxy=proxy if track.needs_proxy else None, uri=track.url,
progress=progress out=save_path,
)) headers=service.session.headers,
proxy=proxy if track.needs_proxy else None,
progress=progress
))
track.path = save_path track.path = save_path
if drm: if drm:
drm.decrypt(save_path) drm.decrypt(save_path)
track.drm = None track.drm = None
if callable(track.OnDecrypted): if callable(track.OnDecrypted):
track.OnDecrypted(track) track.OnDecrypted(track)
except KeyboardInterrupt: except KeyboardInterrupt:
self.DL_POOL_STOP.set() self.DL_POOL_STOP.set()
progress(downloaded="[yellow]STOPPED") progress(downloaded="[yellow]STOPPED")
@ -917,25 +930,27 @@ class dl:
progress(downloaded="[red]FAILED") progress(downloaded="[red]FAILED")
raise raise
except (Exception, KeyboardInterrupt): except (Exception, KeyboardInterrupt):
cleanup() if not self.DL_POOL_SKIP.is_set():
cleanup()
raise raise
if self.DL_POOL_STOP.is_set(): if self.DL_POOL_STOP.is_set():
# we stopped during the download, let's exit # we stopped during the download, let's exit
return return
if track.path.stat().st_size <= 3: # Empty UTF-8 BOM == 3 bytes if not self.DL_POOL_SKIP.is_set():
raise IOError( if track.path.stat().st_size <= 3: # Empty UTF-8 BOM == 3 bytes
"Download failed, the downloaded file is empty. " raise IOError(
f"This {'was' if track.needs_proxy else 'was not'} downloaded with a proxy." + "Download failed, the downloaded file is empty. "
( f"This {'was' if track.needs_proxy else 'was not'} downloaded with a proxy." +
" Perhaps you need to set `needs_proxy` as True to use the proxy for this track." (
if not track.needs_proxy else "" " Perhaps you need to set `needs_proxy` as True to use the proxy for this track."
if not track.needs_proxy else ""
)
) )
)
if callable(track.OnDownloaded): if callable(track.OnDownloaded):
track.OnDownloaded(track) track.OnDownloaded(track)
@staticmethod @staticmethod
def get_profile(service: str) -> Optional[str]: def get_profile(service: str) -> Optional[str]:

View File

@ -283,6 +283,7 @@ class DASH:
save_path: Path, save_path: Path,
save_dir: Path, save_dir: Path,
stop_event: Event, stop_event: Event,
skip_event: Event,
progress: partial, progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
@ -458,6 +459,10 @@ class DASH:
else: else:
drm = None drm = None
if skip_event.is_set():
progress(downloaded="[yellow]SKIPPED")
return
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int: def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if stop_event.is_set(): if stop_event.is_set():
# the track already started downloading, but another failed or was stopped # the track already started downloading, but another failed or was stopped

View File

@ -184,6 +184,7 @@ class HLS:
save_path: Path, save_path: Path,
save_dir: Path, save_dir: Path,
stop_event: Event, stop_event: Event,
skip_event: Event,
progress: partial, progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
@ -280,6 +281,10 @@ class HLS:
finally: finally:
segment_key.put(newest_segment_key) segment_key.put(newest_segment_key)
if skip_event.is_set():
progress(downloaded="[yellow]SKIPPING")
return 0
if not segment.uri.startswith(segment.base_uri): if not segment.uri.startswith(segment.base_uri):
segment.uri = segment.base_uri + segment.uri segment.uri = segment.base_uri + segment.uri