Add full support for CTRL+C on HLS and DASH

This commit is contained in:
rlaphoenix 2023-02-28 16:54:52 +00:00
parent 8365d798a4
commit 383e7d9647
3 changed files with 104 additions and 76 deletions

View File

@ -736,6 +736,7 @@ class dl:
HLS.download_track( HLS.download_track(
track=track, track=track,
save_dir=save_dir, save_dir=save_dir,
stop_event=self.DL_POOL_STOP,
progress=progress, progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,
@ -745,6 +746,7 @@ class dl:
DASH.download_track( DASH.download_track(
track=track, track=track,
save_dir=save_dir, save_dir=save_dir,
stop_event=self.DL_POOL_STOP,
progress=progress, progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,

View File

@ -273,6 +273,7 @@ class DASH:
def download_track( def download_track(
track: AnyTrack, track: AnyTrack,
save_dir: Path, save_dir: Path,
stop_event: Event,
progress: partial, progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
@ -445,10 +446,8 @@ class DASH:
raise ValueError("license_widevine func must be supplied to use Widevine DRM") raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm) license_widevine(drm)
state_event = Event()
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int: def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if state_event.is_set(): if stop_event.is_set():
return 0 return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -509,43 +508,57 @@ class DASH:
last_speed_refresh = time.time() last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool: with ThreadPoolExecutor(max_workers=16) as pool:
try: finished_threads = 0
finished_threads = 0 has_stopped = False
for download in futures.as_completed(( has_failed = False
pool.submit( for download in futures.as_completed((
download_segment, pool.submit(
filename=str(i).zfill(len(str(len(segments)))), download_segment,
segment=segment filename=str(i).zfill(len(str(len(segments)))),
) segment=segment
for i, segment in enumerate(segments) )
)): for i, segment in enumerate(segments)
finished_threads += 1 )):
e = download.exception() finished_threads += 1
if e: try:
state_event.set() download_size = download.result()
traceback.print_exception(e) except KeyboardInterrupt:
log.error(f"Segment Download worker threw an unhandled exception: {e!r}") stop_event.set()
sys.exit(1) if not has_stopped:
else: has_stopped = True
progress(advance=1) progress(downloaded="[orange]STOPPING")
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
else:
if stop_event.is_set():
# skipped
continue
progress(advance=1)
now = time.time() now = time.time()
time_since = now - last_speed_refresh time_since = now - last_speed_refresh
download_size = download.result() if download_size: # no size == skipped dl
if download_size: # no size == skipped dl download_sizes.append(download_size)
download_sizes.append(download_size)
if time_since > 5 or finished_threads == len(segments): if time_since > 5 or finished_threads == len(segments):
data_size = sum(download_sizes) data_size = sum(download_sizes)
download_speed = data_size / time_since download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s") progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now last_speed_refresh = now
download_sizes.clear() download_sizes.clear()
except KeyboardInterrupt: if has_failed:
state_event.set() progress(downloaded="[red]FAILED")
log.info("Received Keyboard Interrupt, stopping...") if has_stopped:
return progress(downloaded="[yellow]STOPPED")
@staticmethod @staticmethod
def get_language(*options: Any) -> Optional[Language]: def get_language(*options: Any) -> Optional[Language]:

View File

@ -182,6 +182,7 @@ class HLS:
def download_track( def download_track(
track: AnyTrack, track: AnyTrack,
save_dir: Path, save_dir: Path,
stop_event: Event,
progress: partial, progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
@ -212,10 +213,8 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.") log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1) sys.exit(1)
state_event = Event()
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int: def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if state_event.is_set(): if stop_event.is_set():
return 0 return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -347,45 +346,59 @@ class HLS:
last_speed_refresh = time.time() last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool: with ThreadPoolExecutor(max_workers=16) as pool:
try: finished_threads = 0
finished_threads = 0 has_stopped = False
for download in futures.as_completed(( has_failed = False
pool.submit( for download in futures.as_completed((
download_segment, pool.submit(
filename=str(i).zfill(len(str(len(master.segments)))), download_segment,
segment=segment, filename=str(i).zfill(len(str(len(master.segments)))),
init_data=init_data, segment=segment,
segment_key=segment_key init_data=init_data,
) segment_key=segment_key
for i, segment in enumerate(master.segments) )
)): for i, segment in enumerate(master.segments)
finished_threads += 1 )):
e = download.exception() finished_threads += 1
if e: try:
state_event.set() download_size = download.result()
traceback.print_exception(e) except KeyboardInterrupt:
log.error(f"Segment Download worker threw an unhandled exception: {e!r}") stop_event.set()
sys.exit(1) if not has_stopped:
else: has_stopped = True
progress(advance=1) progress(downloaded="[orange]STOPPING")
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
else:
if stop_event.is_set():
# skipped
continue
progress(advance=1)
now = time.time() now = time.time()
time_since = now - last_speed_refresh time_since = now - last_speed_refresh
download_size = download.result() if download_size: # no size == skipped dl
if download_size: # no size == skipped dl download_sizes.append(download_size)
download_sizes.append(download_size)
if time_since > 5 or finished_threads == len(master.segments): if time_since > 5 or finished_threads == len(master.segments):
data_size = sum(download_sizes) data_size = sum(download_sizes)
download_speed = data_size / time_since download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s") progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now last_speed_refresh = now
download_sizes.clear() download_sizes.clear()
except KeyboardInterrupt: if has_failed:
state_event.set() progress(downloaded="[red]FAILED")
log.info("Received Keyboard Interrupt, stopping...") if has_stopped:
return progress(downloaded="[yellow]STOPPED")
@staticmethod @staticmethod
def get_drm( def get_drm(