From 92895426b31366afc97004d2c86dc7d9ef8c0f62 Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Sat, 25 Feb 2023 13:45:17 +0000 Subject: [PATCH] Replace tqdm progress bars with rich progress bars --- devine/commands/dl.py | 33 ++++++----- devine/core/downloaders/aria2c.py | 56 ++++++++++++++++++- devine/core/manifests/dash.py | 78 ++++++++++++++++---------- devine/core/manifests/hls.py | 91 ++++++++++++++++++------------- devine/core/tracks/tracks.py | 31 ++++++++++- 5 files changed, 202 insertions(+), 87 deletions(-) diff --git a/devine/commands/dl.py b/devine/commands/dl.py index 098a6ad..a04ebb2 100644 --- a/devine/commands/dl.py +++ b/devine/commands/dl.py @@ -27,7 +27,7 @@ from pymediainfo import MediaInfo from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.device import Device from pywidevine.remotecdm import RemoteCdm -from tqdm import tqdm +from rich.live import Live from rich.padding import Padding from rich.rule import Rule @@ -332,7 +332,7 @@ class dl: title.tracks.sort_chapters() if list_: - available_tracks = title.tracks.tree() + available_tracks, _ = title.tracks.tree() console.log(available_tracks) continue @@ -413,13 +413,19 @@ class dl: if not subs_only: title.tracks.subtitles.clear() - selected_tracks = title.tracks.tree() - console.log(selected_tracks) + selected_tracks, tracks_progress_callables = title.tracks.tree(add_progress=True) if skip_dl: console.log("Skipping Download...") else: - with tqdm(total=len(title.tracks)) as pbar: + with Live( + Padding( + selected_tracks, + (0, 5, 1, 5) + ), + console=console, + refresh_per_second=5 + ): with ThreadPoolExecutor(workers) as pool: try: for download in futures.as_completed(( @@ -445,9 +451,10 @@ class dl: cdm_only=cdm_only, vaults_only=vaults_only, export=export - ) + ), + progress=tracks_progress_callables[i] ) - for track in title.tracks + for i, track in enumerate(title.tracks) )): if download.cancelled(): continue @@ -458,8 +465,6 @@ class dl: traceback.print_exception(type(e), e, e.__traceback__) self.log.error(f"Download worker threw an unhandled exception: {e!r}") return - else: - pbar.update(1) except KeyboardInterrupt: self.DL_POOL_STOP.set() pool.shutdown(wait=False, cancel_futures=True) @@ -570,7 +575,8 @@ class dl: service: Service, track: AnyTrack, title: Title_T, - prepare_drm: Callable + prepare_drm: Callable, + progress: partial ): time.sleep(1) if self.DL_POOL_STOP.is_set(): @@ -581,8 +587,6 @@ class dl: else: proxy = None - console.log(f"Downloading: {track}") - if config.directories.temp.is_file(): self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file") sys.exit(1) @@ -612,6 +616,7 @@ class dl: HLS.download_track( track=track, save_dir=save_dir, + progress=progress, session=service.session, proxy=proxy, license_widevine=prepare_drm @@ -620,6 +625,7 @@ class dl: DASH.download_track( track=track, save_dir=save_dir, + progress=progress, session=service.session, proxy=proxy, license_widevine=prepare_drm @@ -631,7 +637,8 @@ class dl: track.url, save_path, service.session.headers, - proxy if track.needs_proxy else None + proxy if track.needs_proxy else None, + progress=progress )) track.path = save_path diff --git a/devine/core/downloaders/aria2c.py b/devine/core/downloaders/aria2c.py index 3efbe52..22cdccf 100644 --- a/devine/core/downloaders/aria2c.py +++ b/devine/core/downloaders/aria2c.py @@ -1,5 +1,8 @@ import asyncio import subprocess +import sys +from asyncio import IncompleteReadError +from functools import partial from pathlib import Path from typing import Optional, Union @@ -13,6 +16,7 @@ async def aria2c( headers: Optional[dict] = None, proxy: Optional[str] = None, silent: bool = False, + progress: Optional[partial] = None, *args: str ) -> int: """ @@ -59,7 +63,7 @@ async def aria2c( "--summary-interval", "0", "--file-allocation", config.aria2c.get("file_allocation", "falloc"), "--console-log-level", "warn", - "--download-result", "hide", + "--download-result", ["hide", "default"][bool(progress)], *args, "-i", "-" ] @@ -84,9 +88,55 @@ async def aria2c( *arguments, stdin=subprocess.PIPE, stderr=[None, subprocess.DEVNULL][silent], - stdout=[None, subprocess.DEVNULL][silent] + stdout=( + subprocess.PIPE if progress else + subprocess.DEVNULL if silent else + None + ) ) - await p.communicate(uri.encode()) + + p.stdin.write(uri.encode()) + await p.stdin.drain() + p.stdin.close() + + if progress: + # I'm sorry for this shameful code, aria2(c) is annoying as f!!! + buffer = b"" + recording = False + while not p.stdout.at_eof(): + try: + byte = await p.stdout.readexactly(1) + except IncompleteReadError: + pass # ignore, the first read will do this + else: + if byte == b"=": # download result log + progress(total=100, completed=100) + break + if byte == b"[": + recording = True + if recording: + buffer += byte + if byte == b"]": + recording = False + if b"FileAlloc" not in buffer: + try: + # id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs + # eta may not always be available + parts = buffer.decode()[1:-1].split() + dl_parts = parts[1].split("(") + if len(dl_parts) == 2: + # might otherwise be e.g., 0B/0B, with no % symbol provided + progress( + total=100, + completed=int(dl_parts[1][:-2]), + downloaded=f"{parts[3].split(':')[1]}/s" + ) + except Exception as e: + print(f"Aria2c progress failed on {buffer}, {e!r}") + sys.exit(1) + buffer = b"" + + await p.wait() if p.returncode != 0: raise subprocess.CalledProcessError(p.returncode, arguments) diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py index 3ba5565..9b794cc 100644 --- a/devine/core/manifests/dash.py +++ b/devine/core/manifests/dash.py @@ -11,6 +11,7 @@ import traceback from concurrent import futures from concurrent.futures import ThreadPoolExecutor from copy import copy +from functools import partial from hashlib import md5 from pathlib import Path from threading import Event @@ -23,7 +24,7 @@ from langcodes import Language, tag_is_valid from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.pssh import PSSH from requests import Session -from tqdm import tqdm +from rich import filesize from devine.core.console import console from devine.core.constants import AnyTrack @@ -274,6 +275,7 @@ class DASH: def download_track( track: AnyTrack, save_dir: Path, + progress: partial, session: Optional[Session] = None, proxy: Optional[str] = None, license_widevine: Optional[Callable] = None @@ -447,10 +449,10 @@ class DASH: state_event = Event() - def download_segment(filename: str, segment: tuple[str, Optional[str]]): + def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int: time.sleep(0.1) if state_event.is_set(): - return + return 0 segment_save_path = (save_dir / filename).with_suffix(".mp4") @@ -476,6 +478,8 @@ class DASH: silent=True )) + data_size = len(init_data or b"") + if isinstance(track, Audio) or init_data: with open(segment_save_path, "rb+") as f: segment_data = f.read() @@ -492,6 +496,7 @@ class DASH: f.seek(0) f.write(init_data) f.write(segment_data) + data_size += len(segment_data) if drm: # TODO: What if the manifest does not mention DRM, but has DRM @@ -500,33 +505,48 @@ class DASH: if callable(track.OnDecrypted): track.OnDecrypted(track) - with tqdm(total=len(segments), unit="segments") as pbar: - with ThreadPoolExecutor(max_workers=16) as pool: - try: - for download in futures.as_completed(( - pool.submit( - download_segment, - filename=str(i).zfill(len(str(len(segments)))), - segment=segment + return data_size + + progress(total=len(segments)) + + download_start_time = time.time() + download_sizes = [] + + with ThreadPoolExecutor(max_workers=16) as pool: + try: + for download in futures.as_completed(( + pool.submit( + download_segment, + filename=str(i).zfill(len(str(len(segments)))), + segment=segment + ) + for i, segment in enumerate(segments) + )): + if download.cancelled(): + continue + e = download.exception() + if e: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + traceback.print_exception(e) + log.error(f"Segment Download worker threw an unhandled exception: {e!r}") + sys.exit(1) + else: + download_size = download.result() + elapsed_time = time.time() - download_start_time + download_sizes.append(download_size) + while elapsed_time - len(download_sizes) > 10: + download_sizes.pop(0) + download_speed = sum(download_sizes) / len(download_sizes) + progress( + advance=1, + downloaded=f"DASH {filesize.decimal(download_speed)}/s" ) - for i, segment in enumerate(segments) - )): - if download.cancelled(): - continue - e = download.exception() - if e: - state_event.set() - pool.shutdown(wait=False, cancel_futures=True) - traceback.print_exception(e) - log.error(f"Segment Download worker threw an unhandled exception: {e!r}") - sys.exit(1) - else: - pbar.update(1) - except KeyboardInterrupt: - state_event.set() - pool.shutdown(wait=False, cancel_futures=True) - console.log("Received Keyboard Interrupt, stopping...") - return + except KeyboardInterrupt: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + console.log("Received Keyboard Interrupt, stopping...") + return @staticmethod def get_language(*options: Any) -> Optional[Language]: diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py index c71993b..495fb32 100644 --- a/devine/core/manifests/hls.py +++ b/devine/core/manifests/hls.py @@ -8,6 +8,7 @@ import time import traceback from concurrent import futures from concurrent.futures import ThreadPoolExecutor +from functools import partial from hashlib import md5 from pathlib import Path from queue import Queue @@ -21,7 +22,7 @@ from m3u8 import M3U8 from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.pssh import PSSH from requests import Session -from tqdm import tqdm +from rich import filesize from devine.core.console import console from devine.core.constants import AnyTrack @@ -183,6 +184,7 @@ class HLS: def download_track( track: AnyTrack, save_dir: Path, + progress: partial, session: Optional[Session] = None, proxy: Optional[str] = None, license_widevine: Optional[Callable] = None @@ -214,16 +216,10 @@ class HLS: state_event = Event() - def download_segment( - filename: str, - segment: m3u8.Segment, - init_data: Queue, - segment_key: Queue, - range_offset: Queue - ) -> None: + def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue) -> int: time.sleep(0.1) if state_event.is_set(): - return + return 0 segment_save_path = (save_dir / filename).with_suffix(".mp4") @@ -255,7 +251,7 @@ class HLS: segment_key.put(newest_segment_key) if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment): - return + return 0 newest_init_data = init_data.get() if segment.init_section and (not newest_init_data or segment.discontinuity): @@ -302,6 +298,8 @@ class HLS: silent=True )) + data_size = len(newest_init_data or b"") + if isinstance(track, Audio) or newest_init_data: with open(segment_save_path, "rb+") as f: segment_data = f.read() @@ -318,6 +316,7 @@ class HLS: f.seek(0) f.write(newest_init_data) f.write(segment_data) + data_size += len(segment_data) if newest_segment_key[0]: newest_segment_key[0].decrypt(segment_save_path) @@ -325,6 +324,8 @@ class HLS: if callable(track.OnDecrypted): track.OnDecrypted(track) + return data_size + segment_key = Queue(maxsize=1) init_data = Queue(maxsize=1) range_offset = Queue(maxsize=1) @@ -344,36 +345,48 @@ class HLS: init_data.put(None) range_offset.put(0) - with tqdm(total=len(master.segments), unit="segments") as pbar: - with ThreadPoolExecutor(max_workers=16) as pool: - try: - for download in futures.as_completed(( - pool.submit( - download_segment, - filename=str(i).zfill(len(str(len(master.segments)))), - segment=segment, - init_data=init_data, - segment_key=segment_key, - range_offset=range_offset + progress(total=len(master.segments)) + + download_start_time = time.time() + download_sizes = [] + + with ThreadPoolExecutor(max_workers=16) as pool: + try: + for download in futures.as_completed(( + pool.submit( + download_segment, + filename=str(i).zfill(len(str(len(master.segments)))), + segment=segment, + init_data=init_data, + segment_key=segment_key + ) + for i, segment in enumerate(master.segments) + )): + if download.cancelled(): + continue + e = download.exception() + if e: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + traceback.print_exception(e) + log.error(f"Segment Download worker threw an unhandled exception: {e!r}") + sys.exit(1) + else: + download_size = download.result() + elapsed_time = time.time() - download_start_time + download_sizes.append(download_size) + while elapsed_time - len(download_sizes) > 10: + download_sizes.pop(0) + download_speed = sum(download_sizes) / len(download_sizes) + progress( + advance=1, + downloaded=f"HLS {filesize.decimal(download_speed)}/s" ) - for i, segment in enumerate(master.segments) - )): - if download.cancelled(): - continue - e = download.exception() - if e: - state_event.set() - pool.shutdown(wait=False, cancel_futures=True) - traceback.print_exception(e) - log.error(f"Segment Download worker threw an unhandled exception: {e!r}") - sys.exit(1) - else: - pbar.update(1) - except KeyboardInterrupt: - state_event.set() - pool.shutdown(wait=False, cancel_futures=True) - console.log("Received Keyboard Interrupt, stopping...") - return + except KeyboardInterrupt: + state_event.set() + pool.shutdown(wait=False, cancel_futures=True) + console.log("Received Keyboard Interrupt, stopping...") + return @staticmethod def get_drm( diff --git a/devine/core/tracks/tracks.py b/devine/core/tracks/tracks.py index 548d1aa..743a252 100644 --- a/devine/core/tracks/tracks.py +++ b/devine/core/tracks/tracks.py @@ -2,14 +2,18 @@ from __future__ import annotations import logging import subprocess +from functools import partial from pathlib import Path from typing import Callable, Iterator, Optional, Sequence, Union from Cryptodome.Random import get_random_bytes from langcodes import Language, closest_supported_match +from rich.progress import Progress, TextColumn, SpinnerColumn, BarColumn, TimeRemainingColumn +from rich.table import Table from rich.tree import Tree from devine.core.config import config +from devine.core.console import console from devine.core.constants import LANGUAGE_MAX_DISTANCE, LANGUAGE_MUX_MAP, AnyTrack, TrackT from devine.core.tracks.audio import Audio from devine.core.tracks.chapter import Chapter @@ -87,9 +91,11 @@ class Tracks: return rep - def tree(self) -> Tree: + def tree(self, add_progress: bool = False) -> tuple[Tree, list[partial]]: all_tracks = [*list(self), *self.chapters] + progress_callables = [] + tree = Tree("", hide_root=True) for track_type in self.TRACK_ORDER_MAP: tracks = list(x for x in all_tracks if isinstance(x, track_type)) @@ -99,9 +105,28 @@ class Tracks: track_type_plural = track_type.__name__ + ("s" if track_type != Audio and num_tracks != 1 else "") tracks_tree = tree.add(f"[repr.number]{num_tracks}[/] {track_type_plural}") for track in tracks: - tracks_tree.add(str(track)[6:], style="text2") + if add_progress and track_type != Chapter: + progress = Progress( + TextColumn("[progress.description]{task.description}"), + SpinnerColumn(), + BarColumn(), + "•", + TimeRemainingColumn(compact=True, elapsed_when_finished=True), + "•", + TextColumn("[progress.data.speed]{task.fields[downloaded]}"), + console=console, + speed_estimate_period=10 + ) + task = progress.add_task("", downloaded="-") + progress_callables.append(partial(progress.update, task_id=task)) + track_table = Table.grid() + track_table.add_row(str(track)[6:], style="text2") + track_table.add_row(progress) + tracks_tree.add(track_table) + else: + tracks_tree.add(str(track)[6:], style="text2") - return tree + return tree, progress_callables def exists(self, by_id: Optional[str] = None, by_url: Optional[Union[str, list[str]]] = None) -> bool: """Check if a track already exists by various methods."""