Shutdown HLS & DASH dl pool, pass exceptions to dl

This results in a noticeably faster speed cancelling segmented track downloads on CTRL+C and Errors. It's also reducing code duplication as the dl code will now handle the exception and cleanup for them.

This also simplifies the STOPPING/STOPPED and FAILING/FAILED status messages by quite a bit.
This commit is contained in:
rlaphoenix 2023-03-01 10:45:04 +00:00
parent fbe78308eb
commit 9f48aab80c
2 changed files with 72 additions and 98 deletions

View File

@ -4,10 +4,8 @@ import base64
import logging
import math
import re
import shutil
import sys
import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from copy import copy
@ -450,7 +448,8 @@ class DASH:
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if stop_event.is_set():
return 0
# the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -507,8 +506,6 @@ class DASH:
progress(total=len(segments))
finished_threads = 0
has_stopped = False
has_failed = False
download_sizes = []
last_speed_refresh = time.time()
@ -522,59 +519,49 @@ class DASH:
for i, segment in enumerate(segments)
)):
finished_threads += 1
try:
download_size = download.result()
except KeyboardInterrupt:
stop_event.set()
stop_event.set() # skip pending track downloads
progress(downloaded="[yellow]STOPPING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]STOPPED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
continue
stop_event.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise e
else:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
if stop_event.is_set():
if not has_stopped:
has_stopped = True
progress(downloaded="[orange]STOPPING")
continue
now = time.time()
time_since = now - last_speed_refresh
progress(advance=1)
if download_size: # no size == skipped dl
download_sizes.append(download_size)
now = time.time()
time_since = now - last_speed_refresh
if download_sizes and (time_since > 5 or finished_threads == len(segments)):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if download_size: # no size == skipped dl
download_sizes.append(download_size)
with open(save_path, "wb") as f:
for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
if download_sizes and (time_since > 5 or finished_threads == len(segments)):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
try:
if has_stopped:
progress(downloaded="[yellow]STOPPED")
return
if has_failed:
progress(downloaded="[red]FAILED")
return
with open(save_path, "wb") as f:
for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
track.path = save_path
finally:
shutil.rmtree(save_dir)
track.path = save_path
save_dir.rmdir()
@staticmethod
def get_language(*options: Any) -> Optional[Language]:

View File

@ -2,10 +2,8 @@ from __future__ import annotations
import logging
import re
import shutil
import sys
import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from functools import partial
@ -217,7 +215,8 @@ class HLS:
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if stop_event.is_set():
return 0
# the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -345,8 +344,6 @@ class HLS:
progress(total=len(master.segments))
finished_threads = 0
has_stopped = False
has_failed = False
download_sizes = []
last_speed_refresh = time.time()
@ -362,59 +359,49 @@ class HLS:
for i, segment in enumerate(master.segments)
)):
finished_threads += 1
try:
download_size = download.result()
except KeyboardInterrupt:
stop_event.set()
stop_event.set() # skip pending track downloads
progress(downloaded="[yellow]STOPPING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]STOPPED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e:
stop_event.set()
if has_stopped:
# we don't care because we were stopping anyway
continue
if not has_failed:
has_failed = True
progress(downloaded="[red]FAILING")
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
continue
stop_event.set() # skip pending track downloads
progress(downloaded="[red]FAILING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[red]FAILED")
# tell dl that it failed
# the pool is already shut down, so exiting loop is fine
raise e
else:
# it successfully downloaded, and it was not cancelled
progress(advance=1)
if stop_event.is_set():
if not has_stopped:
has_stopped = True
progress(downloaded="[orange]STOPPING")
continue
now = time.time()
time_since = now - last_speed_refresh
progress(advance=1)
if download_size: # no size == skipped dl
download_sizes.append(download_size)
now = time.time()
time_since = now - last_speed_refresh
if download_sizes and (time_since > 5 or finished_threads == len(master.segments)):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if download_size: # no size == skipped dl
download_sizes.append(download_size)
with open(save_path, "wb") as f:
for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
if download_sizes and (time_since > 5 or finished_threads == len(master.segments)):
data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
try:
if has_stopped:
progress(downloaded="[yellow]STOPPED")
return
if has_failed:
progress(downloaded="[red]FAILED")
return
with open(save_path, "wb") as f:
for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
track.path = save_path
finally:
shutil.rmtree(save_dir)
track.path = save_path
save_dir.rmdir()
@staticmethod
def get_drm(