Add full support for CTRL+C on HLS and DASH

This commit is contained in:
rlaphoenix 2023-02-28 16:54:52 +00:00
parent 8365d798a4
commit 383e7d9647
3 changed files with 104 additions and 76 deletions

View File

@ -736,6 +736,7 @@ class dl:
HLS.download_track( HLS.download_track(
track=track, track=track,
save_dir=save_dir, save_dir=save_dir,
stop_event=self.DL_POOL_STOP,
progress=progress, progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,
@ -745,6 +746,7 @@ class dl:
DASH.download_track( DASH.download_track(
track=track, track=track,
save_dir=save_dir, save_dir=save_dir,
stop_event=self.DL_POOL_STOP,
progress=progress, progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,

View File

@ -273,6 +273,7 @@ class DASH:
def download_track( def download_track(
track: AnyTrack, track: AnyTrack,
save_dir: Path, save_dir: Path,
stop_event: Event,
progress: partial, progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
@ -445,10 +446,8 @@ class DASH:
raise ValueError("license_widevine func must be supplied to use Widevine DRM") raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm) license_widevine(drm)
state_event = Event()
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int: def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if state_event.is_set(): if stop_event.is_set():
return 0 return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -509,8 +508,9 @@ class DASH:
last_speed_refresh = time.time() last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool: with ThreadPoolExecutor(max_workers=16) as pool:
try:
finished_threads = 0 finished_threads = 0
has_stopped = False
has_failed = False
for download in futures.as_completed(( for download in futures.as_completed((
pool.submit( pool.submit(
download_segment, download_segment,
@ -520,19 +520,32 @@ class DASH:
for i, segment in enumerate(segments) for i, segment in enumerate(segments)
)): )):
finished_threads += 1 finished_threads += 1
e = download.exception() try:
if e: download_size = download.result()
state_event.set() except KeyboardInterrupt:
stop_event.set()
if not has_stopped:
has_stopped = True
progress(downloaded="[orange]STOPPING")
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e) traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}") log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else: else:
if stop_event.is_set():
# skipped
continue
progress(advance=1) progress(advance=1)
now = time.time() now = time.time()
time_since = now - last_speed_refresh time_since = now - last_speed_refresh
download_size = download.result()
if download_size: # no size == skipped dl if download_size: # no size == skipped dl
download_sizes.append(download_size) download_sizes.append(download_size)
@ -542,10 +555,10 @@ class DASH:
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s") progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now last_speed_refresh = now
download_sizes.clear() download_sizes.clear()
except KeyboardInterrupt: if has_failed:
state_event.set() progress(downloaded="[red]FAILED")
log.info("Received Keyboard Interrupt, stopping...") if has_stopped:
return progress(downloaded="[yellow]STOPPED")
@staticmethod @staticmethod
def get_language(*options: Any) -> Optional[Language]: def get_language(*options: Any) -> Optional[Language]:

View File

@ -182,6 +182,7 @@ class HLS:
def download_track( def download_track(
track: AnyTrack, track: AnyTrack,
save_dir: Path, save_dir: Path,
stop_event: Event,
progress: partial, progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
@ -212,10 +213,8 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.") log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1) sys.exit(1)
state_event = Event()
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int: def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if state_event.is_set(): if stop_event.is_set():
return 0 return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -347,8 +346,9 @@ class HLS:
last_speed_refresh = time.time() last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool: with ThreadPoolExecutor(max_workers=16) as pool:
try:
finished_threads = 0 finished_threads = 0
has_stopped = False
has_failed = False
for download in futures.as_completed(( for download in futures.as_completed((
pool.submit( pool.submit(
download_segment, download_segment,
@ -360,19 +360,32 @@ class HLS:
for i, segment in enumerate(master.segments) for i, segment in enumerate(master.segments)
)): )):
finished_threads += 1 finished_threads += 1
e = download.exception() try:
if e: download_size = download.result()
state_event.set() except KeyboardInterrupt:
stop_event.set()
if not has_stopped:
has_stopped = True
progress(downloaded="[orange]STOPPING")
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e) traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}") log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else: else:
if stop_event.is_set():
# skipped
continue
progress(advance=1) progress(advance=1)
now = time.time() now = time.time()
time_since = now - last_speed_refresh time_since = now - last_speed_refresh
download_size = download.result()
if download_size: # no size == skipped dl if download_size: # no size == skipped dl
download_sizes.append(download_size) download_sizes.append(download_size)
@ -382,10 +395,10 @@ class HLS:
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s") progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now last_speed_refresh = now
download_sizes.clear() download_sizes.clear()
except KeyboardInterrupt: if has_failed:
state_event.set() progress(downloaded="[red]FAILED")
log.info("Received Keyboard Interrupt, stopping...") if has_stopped:
return progress(downloaded="[yellow]STOPPED")
@staticmethod @staticmethod
def get_drm( def get_drm(