From 383e7d9647720c5ea8b1c039ee094ee02e9f651d Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Tue, 28 Feb 2023 16:54:52 +0000 Subject: [PATCH] Add full support for CTRL+C on HLS and DASH --- devine/commands/dl.py | 2 + devine/core/manifests/dash.py | 87 +++++++++++++++++++-------------- devine/core/manifests/hls.py | 91 ++++++++++++++++++++--------------- 3 files changed, 104 insertions(+), 76 deletions(-) diff --git a/devine/commands/dl.py b/devine/commands/dl.py index ecf052c..f01533b 100644 --- a/devine/commands/dl.py +++ b/devine/commands/dl.py @@ -736,6 +736,7 @@ class dl: HLS.download_track( track=track, save_dir=save_dir, + stop_event=self.DL_POOL_STOP, progress=progress, session=service.session, proxy=proxy, @@ -745,6 +746,7 @@ class dl: DASH.download_track( track=track, save_dir=save_dir, + stop_event=self.DL_POOL_STOP, progress=progress, session=service.session, proxy=proxy, diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py index 86ebde7..57edcfe 100644 --- a/devine/core/manifests/dash.py +++ b/devine/core/manifests/dash.py @@ -273,6 +273,7 @@ class DASH: def download_track( track: AnyTrack, save_dir: Path, + stop_event: Event, progress: partial, session: Optional[Session] = None, proxy: Optional[str] = None, @@ -445,10 +446,8 @@ class DASH: raise ValueError("license_widevine func must be supplied to use Widevine DRM") license_widevine(drm) - state_event = Event() - def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int: - if state_event.is_set(): + if stop_event.is_set(): return 0 segment_save_path = (save_dir / filename).with_suffix(".mp4") @@ -509,43 +508,57 @@ class DASH: last_speed_refresh = time.time() with ThreadPoolExecutor(max_workers=16) as pool: - try: - finished_threads = 0 - for download in futures.as_completed(( - pool.submit( - download_segment, - filename=str(i).zfill(len(str(len(segments)))), - segment=segment - ) - for i, segment in enumerate(segments) - )): - finished_threads += 1 - e = download.exception() - if e: - state_event.set() - traceback.print_exception(e) - log.error(f"Segment Download worker threw an unhandled exception: {e!r}") - sys.exit(1) - else: - progress(advance=1) + finished_threads = 0 + has_stopped = False + has_failed = False + for download in futures.as_completed(( + pool.submit( + download_segment, + filename=str(i).zfill(len(str(len(segments)))), + segment=segment + ) + for i, segment in enumerate(segments) + )): + finished_threads += 1 + try: + download_size = download.result() + except KeyboardInterrupt: + stop_event.set() + if not has_stopped: + has_stopped = True + progress(downloaded="[orange]STOPPING") + except Exception as e: + stop_event.set() + if has_stopped: + # we don't care because we were stopping anyway + continue + if not has_failed: + has_failed = True + progress(downloaded="[red]FAILING") + traceback.print_exception(e) + log.error(f"Segment Download worker threw an unhandled exception: {e!r}") + else: + if stop_event.is_set(): + # skipped + continue + progress(advance=1) - now = time.time() - time_since = now - last_speed_refresh + now = time.time() + time_since = now - last_speed_refresh - download_size = download.result() - if download_size: # no size == skipped dl - download_sizes.append(download_size) + if download_size: # no size == skipped dl + download_sizes.append(download_size) - if time_since > 5 or finished_threads == len(segments): - data_size = sum(download_sizes) - download_speed = data_size / time_since - progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s") - last_speed_refresh = now - download_sizes.clear() - except KeyboardInterrupt: - state_event.set() - log.info("Received Keyboard Interrupt, stopping...") - return + if time_since > 5 or finished_threads == len(segments): + data_size = sum(download_sizes) + download_speed = data_size / time_since + progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s") + last_speed_refresh = now + download_sizes.clear() + if has_failed: + progress(downloaded="[red]FAILED") + if has_stopped: + progress(downloaded="[yellow]STOPPED") @staticmethod def get_language(*options: Any) -> Optional[Language]: diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py index 0c0b1ba..b38beee 100644 --- a/devine/core/manifests/hls.py +++ b/devine/core/manifests/hls.py @@ -182,6 +182,7 @@ class HLS: def download_track( track: AnyTrack, save_dir: Path, + stop_event: Event, progress: partial, session: Optional[Session] = None, proxy: Optional[str] = None, @@ -212,10 +213,8 @@ class HLS: log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.") sys.exit(1) - state_event = Event() - def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int: - if state_event.is_set(): + if stop_event.is_set(): return 0 segment_save_path = (save_dir / filename).with_suffix(".mp4") @@ -347,45 +346,59 @@ class HLS: last_speed_refresh = time.time() with ThreadPoolExecutor(max_workers=16) as pool: - try: - finished_threads = 0 - for download in futures.as_completed(( - pool.submit( - download_segment, - filename=str(i).zfill(len(str(len(master.segments)))), - segment=segment, - init_data=init_data, - segment_key=segment_key - ) - for i, segment in enumerate(master.segments) - )): - finished_threads += 1 - e = download.exception() - if e: - state_event.set() - traceback.print_exception(e) - log.error(f"Segment Download worker threw an unhandled exception: {e!r}") - sys.exit(1) - else: - progress(advance=1) + finished_threads = 0 + has_stopped = False + has_failed = False + for download in futures.as_completed(( + pool.submit( + download_segment, + filename=str(i).zfill(len(str(len(master.segments)))), + segment=segment, + init_data=init_data, + segment_key=segment_key + ) + for i, segment in enumerate(master.segments) + )): + finished_threads += 1 + try: + download_size = download.result() + except KeyboardInterrupt: + stop_event.set() + if not has_stopped: + has_stopped = True + progress(downloaded="[orange]STOPPING") + except Exception as e: + stop_event.set() + if has_stopped: + # we don't care because we were stopping anyway + continue + if not has_failed: + has_failed = True + progress(downloaded="[red]FAILING") + traceback.print_exception(e) + log.error(f"Segment Download worker threw an unhandled exception: {e!r}") + else: + if stop_event.is_set(): + # skipped + continue + progress(advance=1) - now = time.time() - time_since = now - last_speed_refresh + now = time.time() + time_since = now - last_speed_refresh - download_size = download.result() - if download_size: # no size == skipped dl - download_sizes.append(download_size) + if download_size: # no size == skipped dl + download_sizes.append(download_size) - if time_since > 5 or finished_threads == len(master.segments): - data_size = sum(download_sizes) - download_speed = data_size / time_since - progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s") - last_speed_refresh = now - download_sizes.clear() - except KeyboardInterrupt: - state_event.set() - log.info("Received Keyboard Interrupt, stopping...") - return + if time_since > 5 or finished_threads == len(master.segments): + data_size = sum(download_sizes) + download_speed = data_size / time_since + progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s") + last_speed_refresh = now + download_sizes.clear() + if has_failed: + progress(downloaded="[red]FAILED") + if has_stopped: + progress(downloaded="[yellow]STOPPED") @staticmethod def get_drm(