Add full support for CTRL+C on HLS and DASH

This commit is contained in:
rlaphoenix 2023-02-28 16:54:52 +00:00
parent 8365d798a4
commit 383e7d9647
3 changed files with 104 additions and 76 deletions

View File

@ -736,6 +736,7 @@ class dl:
HLS.download_track(
track=track,
save_dir=save_dir,
stop_event=self.DL_POOL_STOP,
progress=progress,
session=service.session,
proxy=proxy,
@ -745,6 +746,7 @@ class dl:
DASH.download_track(
track=track,
save_dir=save_dir,
stop_event=self.DL_POOL_STOP,
progress=progress,
session=service.session,
proxy=proxy,

View File

@ -273,6 +273,7 @@ class DASH:
def download_track(
track: AnyTrack,
save_dir: Path,
stop_event: Event,
progress: partial,
session: Optional[Session] = None,
proxy: Optional[str] = None,
@ -445,10 +446,8 @@ class DASH:
raise ValueError("license_widevine func must be supplied to use Widevine DRM")
license_widevine(drm)
state_event = Event()
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if state_event.is_set():
if stop_event.is_set():
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -509,43 +508,57 @@ class DASH:
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool:
try:
finished_threads = 0
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
)
for i, segment in enumerate(segments)
)):
finished_threads += 1
e = download.exception()
if e:
state_event.set()
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
progress(advance=1)
finished_threads = 0
has_stopped = False
has_failed = False
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
)
for i, segment in enumerate(segments)
)):
finished_threads += 1
try:
download_size = download.result()
except KeyboardInterrupt:
stop_event.set()
if not has_stopped:
has_stopped = True
progress(downloaded="[orange]STOPPING")
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
else:
if stop_event.is_set():
# skipped
continue
progress(advance=1)
now = time.time()
time_since = now - last_speed_refresh
now = time.time()
time_since = now - last_speed_refresh
download_size = download.result()
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if time_since > 5 or finished_threads == len(segments):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
except KeyboardInterrupt:
state_event.set()
log.info("Received Keyboard Interrupt, stopping...")
return
if time_since > 5 or finished_threads == len(segments):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if has_failed:
progress(downloaded="[red]FAILED")
if has_stopped:
progress(downloaded="[yellow]STOPPED")
@staticmethod
def get_language(*options: Any) -> Optional[Language]:

View File

@ -182,6 +182,7 @@ class HLS:
def download_track(
track: AnyTrack,
save_dir: Path,
stop_event: Event,
progress: partial,
session: Optional[Session] = None,
proxy: Optional[str] = None,
@ -212,10 +213,8 @@ class HLS:
log.error("Track's HLS playlist has no segments, expecting an invariant M3U8 playlist.")
sys.exit(1)
state_event = Event()
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if state_event.is_set():
if stop_event.is_set():
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -347,45 +346,59 @@ class HLS:
last_speed_refresh = time.time()
with ThreadPoolExecutor(max_workers=16) as pool:
try:
finished_threads = 0
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
finished_threads += 1
e = download.exception()
if e:
state_event.set()
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
progress(advance=1)
finished_threads = 0
has_stopped = False
has_failed = False
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
finished_threads += 1
try:
download_size = download.result()
except KeyboardInterrupt:
stop_event.set()
if not has_stopped:
has_stopped = True
progress(downloaded="[orange]STOPPING")
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
else:
if stop_event.is_set():
# skipped
continue
progress(advance=1)
now = time.time()
time_since = now - last_speed_refresh
now = time.time()
time_since = now - last_speed_refresh
download_size = download.result()
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if download_size: # no size == skipped dl
download_sizes.append(download_size)
if time_since > 5 or finished_threads == len(master.segments):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
except KeyboardInterrupt:
state_event.set()
log.info("Received Keyboard Interrupt, stopping...")
return
if time_since > 5 or finished_threads == len(master.segments):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if has_failed:
progress(downloaded="[red]FAILED")
if has_stopped:
progress(downloaded="[yellow]STOPPED")
@staticmethod
def get_drm(