Shut down HLS & DASH download pool, pass exceptions to dl

This results in noticeably faster cancellation of segmented track downloads on CTRL+C and on errors. It also reduces code duplication, as the dl code will now handle the exception and the cleanup for them.

This also considerably simplifies the STOPPING/STOPPED and FAILING/FAILED status messages.
This commit is contained in:
rlaphoenix 2023-03-01 10:45:04 +00:00
parent fbe78308eb
commit 9f48aab80c
2 changed files with 72 additions and 98 deletions

View File

@ -4,10 +4,8 @@ import base64
import logging import logging
import math import math
import re import re
import shutil
import sys import sys
import time import time
import traceback
from concurrent import futures from concurrent import futures
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from copy import copy from copy import copy
@ -450,7 +448,8 @@ class DASH:
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int: def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
if stop_event.is_set(): if stop_event.is_set():
return 0 # the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -507,8 +506,6 @@ class DASH:
progress(total=len(segments)) progress(total=len(segments))
finished_threads = 0 finished_threads = 0
has_stopped = False
has_failed = False
download_sizes = [] download_sizes = []
last_speed_refresh = time.time() last_speed_refresh = time.time()
@ -522,59 +519,49 @@ class DASH:
for i, segment in enumerate(segments) for i, segment in enumerate(segments)
)): )):
finished_threads += 1 finished_threads += 1
try: try:
download_size = download.result() download_size = download.result()
except KeyboardInterrupt: except KeyboardInterrupt:
stop_event.set() stop_event.set() # skip pending track downloads
progress(downloaded="[yellow]STOPPING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]STOPPED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e: except Exception as e:
stop_event.set() stop_event.set() # skip pending track downloads
if has_stopped: progress(downloaded="[red]FAILING")
# we don't care because we were stopping anyway pool.shutdown(wait=True, cancel_futures=True)
continue progress(downloaded="[red]FAILED")
if not has_failed: # tell dl that it failed
has_failed = True # the pool is already shut down, so exiting loop is fine
progress(downloaded="[red]FAILING") raise e
traceback.print_exception(e) else:
log.error(f"Segment Download worker threw an unhandled exception: {e!r}") # it successfully downloaded, and it was not cancelled
continue progress(advance=1)
if stop_event.is_set(): now = time.time()
if not has_stopped: time_since = now - last_speed_refresh
has_stopped = True
progress(downloaded="[orange]STOPPING")
continue
progress(advance=1) if download_size: # no size == skipped dl
download_sizes.append(download_size)
now = time.time() if download_sizes and (time_since > 5 or finished_threads == len(segments)):
time_since = now - last_speed_refresh data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if download_size: # no size == skipped dl with open(save_path, "wb") as f:
download_sizes.append(download_size) for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
if download_sizes and (time_since > 5 or finished_threads == len(segments)): track.path = save_path
data_size = sum(download_sizes) save_dir.rmdir()
download_speed = data_size / time_since
progress(downloaded=f"DASH {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
try:
if has_stopped:
progress(downloaded="[yellow]STOPPED")
return
if has_failed:
progress(downloaded="[red]FAILED")
return
with open(save_path, "wb") as f:
for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
track.path = save_path
finally:
shutil.rmtree(save_dir)
@staticmethod @staticmethod
def get_language(*options: Any) -> Optional[Language]: def get_language(*options: Any) -> Optional[Language]:

View File

@ -2,10 +2,8 @@ from __future__ import annotations
import logging import logging
import re import re
import shutil
import sys import sys
import time import time
import traceback
from concurrent import futures from concurrent import futures
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial from functools import partial
@ -217,7 +215,8 @@ class HLS:
def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int: def download_segment(filename: str, segment: m3u8.Segment, init_data: Queue, segment_key: Queue) -> int:
if stop_event.is_set(): if stop_event.is_set():
return 0 # the track already started downloading, but another failed or was stopped
raise KeyboardInterrupt()
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -345,8 +344,6 @@ class HLS:
progress(total=len(master.segments)) progress(total=len(master.segments))
finished_threads = 0 finished_threads = 0
has_stopped = False
has_failed = False
download_sizes = [] download_sizes = []
last_speed_refresh = time.time() last_speed_refresh = time.time()
@ -362,59 +359,49 @@ class HLS:
for i, segment in enumerate(master.segments) for i, segment in enumerate(master.segments)
)): )):
finished_threads += 1 finished_threads += 1
try: try:
download_size = download.result() download_size = download.result()
except KeyboardInterrupt: except KeyboardInterrupt:
stop_event.set() stop_event.set() # skip pending track downloads
progress(downloaded="[yellow]STOPPING")
pool.shutdown(wait=True, cancel_futures=True)
progress(downloaded="[yellow]STOPPED")
# tell dl that it was cancelled
# the pool is already shut down, so exiting loop is fine
raise
except Exception as e: except Exception as e:
stop_event.set() stop_event.set() # skip pending track downloads
if has_stopped: progress(downloaded="[red]FAILING")
# we don't care because we were stopping anyway pool.shutdown(wait=True, cancel_futures=True)
continue progress(downloaded="[red]FAILED")
if not has_failed: # tell dl that it failed
has_failed = True # the pool is already shut down, so exiting loop is fine
progress(downloaded="[red]FAILING") raise e
traceback.print_exception(e) else:
log.error(f"Segment Download worker threw an unhandled exception: {e!r}") # it successfully downloaded, and it was not cancelled
continue progress(advance=1)
if stop_event.is_set(): now = time.time()
if not has_stopped: time_since = now - last_speed_refresh
has_stopped = True
progress(downloaded="[orange]STOPPING")
continue
progress(advance=1) if download_size: # no size == skipped dl
download_sizes.append(download_size)
now = time.time() if download_sizes and (time_since > 5 or finished_threads == len(master.segments)):
time_since = now - last_speed_refresh data_size = sum(download_sizes)
download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
if download_size: # no size == skipped dl with open(save_path, "wb") as f:
download_sizes.append(download_size) for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
if download_sizes and (time_since > 5 or finished_threads == len(master.segments)): track.path = save_path
data_size = sum(download_sizes) save_dir.rmdir()
download_speed = data_size / time_since
progress(downloaded=f"HLS {filesize.decimal(download_speed)}/s")
last_speed_refresh = now
download_sizes.clear()
try:
if has_stopped:
progress(downloaded="[yellow]STOPPED")
return
if has_failed:
progress(downloaded="[red]FAILED")
return
with open(save_path, "wb") as f:
for segment_file in sorted(save_dir.iterdir()):
f.write(segment_file.read_bytes())
segment_file.unlink()
track.path = save_path
finally:
shutil.rmtree(save_dir)
@staticmethod @staticmethod
def get_drm( def get_drm(