Replace tqdm progress bars with rich progress bars

This commit is contained in:
rlaphoenix 2023-02-25 13:45:17 +00:00
parent cc69423374
commit 92895426b3
5 changed files with 202 additions and 87 deletions

View File

@ -27,7 +27,7 @@ from pymediainfo import MediaInfo
from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.device import Device from pywidevine.device import Device
from pywidevine.remotecdm import RemoteCdm from pywidevine.remotecdm import RemoteCdm
from tqdm import tqdm from rich.live import Live
from rich.padding import Padding from rich.padding import Padding
from rich.rule import Rule from rich.rule import Rule
@ -332,7 +332,7 @@ class dl:
title.tracks.sort_chapters() title.tracks.sort_chapters()
if list_: if list_:
available_tracks = title.tracks.tree() available_tracks, _ = title.tracks.tree()
console.log(available_tracks) console.log(available_tracks)
continue continue
@ -413,13 +413,19 @@ class dl:
if not subs_only: if not subs_only:
title.tracks.subtitles.clear() title.tracks.subtitles.clear()
selected_tracks = title.tracks.tree() selected_tracks, tracks_progress_callables = title.tracks.tree(add_progress=True)
console.log(selected_tracks)
if skip_dl: if skip_dl:
console.log("Skipping Download...") console.log("Skipping Download...")
else: else:
with tqdm(total=len(title.tracks)) as pbar: with Live(
Padding(
selected_tracks,
(0, 5, 1, 5)
),
console=console,
refresh_per_second=5
):
with ThreadPoolExecutor(workers) as pool: with ThreadPoolExecutor(workers) as pool:
try: try:
for download in futures.as_completed(( for download in futures.as_completed((
@ -445,9 +451,10 @@ class dl:
cdm_only=cdm_only, cdm_only=cdm_only,
vaults_only=vaults_only, vaults_only=vaults_only,
export=export export=export
) ),
progress=tracks_progress_callables[i]
) )
for track in title.tracks for i, track in enumerate(title.tracks)
)): )):
if download.cancelled(): if download.cancelled():
continue continue
@ -458,8 +465,6 @@ class dl:
traceback.print_exception(type(e), e, e.__traceback__) traceback.print_exception(type(e), e, e.__traceback__)
self.log.error(f"Download worker threw an unhandled exception: {e!r}") self.log.error(f"Download worker threw an unhandled exception: {e!r}")
return return
else:
pbar.update(1)
except KeyboardInterrupt: except KeyboardInterrupt:
self.DL_POOL_STOP.set() self.DL_POOL_STOP.set()
pool.shutdown(wait=False, cancel_futures=True) pool.shutdown(wait=False, cancel_futures=True)
@ -570,7 +575,8 @@ class dl:
service: Service, service: Service,
track: AnyTrack, track: AnyTrack,
title: Title_T, title: Title_T,
prepare_drm: Callable prepare_drm: Callable,
progress: partial
): ):
time.sleep(1) time.sleep(1)
if self.DL_POOL_STOP.is_set(): if self.DL_POOL_STOP.is_set():
@ -581,8 +587,6 @@ class dl:
else: else:
proxy = None proxy = None
console.log(f"Downloading: {track}")
if config.directories.temp.is_file(): if config.directories.temp.is_file():
self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file") self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file")
sys.exit(1) sys.exit(1)
@ -612,6 +616,7 @@ class dl:
HLS.download_track( HLS.download_track(
track=track, track=track,
save_dir=save_dir, save_dir=save_dir,
progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,
license_widevine=prepare_drm license_widevine=prepare_drm
@ -620,6 +625,7 @@ class dl:
DASH.download_track( DASH.download_track(
track=track, track=track,
save_dir=save_dir, save_dir=save_dir,
progress=progress,
session=service.session, session=service.session,
proxy=proxy, proxy=proxy,
license_widevine=prepare_drm license_widevine=prepare_drm
@ -631,7 +637,8 @@ class dl:
track.url, track.url,
save_path, save_path,
service.session.headers, service.session.headers,
proxy if track.needs_proxy else None proxy if track.needs_proxy else None,
progress=progress
)) ))
track.path = save_path track.path = save_path

View File

@ -1,5 +1,8 @@
import asyncio import asyncio
import subprocess import subprocess
import sys
from asyncio import IncompleteReadError
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
@ -13,6 +16,7 @@ async def aria2c(
headers: Optional[dict] = None, headers: Optional[dict] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
silent: bool = False, silent: bool = False,
progress: Optional[partial] = None,
*args: str *args: str
) -> int: ) -> int:
""" """
@ -59,7 +63,7 @@ async def aria2c(
"--summary-interval", "0", "--summary-interval", "0",
"--file-allocation", config.aria2c.get("file_allocation", "falloc"), "--file-allocation", config.aria2c.get("file_allocation", "falloc"),
"--console-log-level", "warn", "--console-log-level", "warn",
"--download-result", "hide", "--download-result", ["hide", "default"][bool(progress)],
*args, *args,
"-i", "-" "-i", "-"
] ]
@ -84,9 +88,55 @@ async def aria2c(
*arguments, *arguments,
stdin=subprocess.PIPE, stdin=subprocess.PIPE,
stderr=[None, subprocess.DEVNULL][silent], stderr=[None, subprocess.DEVNULL][silent],
stdout=[None, subprocess.DEVNULL][silent] stdout=(
subprocess.PIPE if progress else
subprocess.DEVNULL if silent else
None
)
) )
await p.communicate(uri.encode())
p.stdin.write(uri.encode())
await p.stdin.drain()
p.stdin.close()
if progress:
# I'm sorry for this shameful code, aria2(c) is annoying as f!!!
buffer = b""
recording = False
while not p.stdout.at_eof():
try:
byte = await p.stdout.readexactly(1)
except IncompleteReadError:
pass # ignore, the first read will do this
else:
if byte == b"=": # download result log
progress(total=100, completed=100)
break
if byte == b"[":
recording = True
if recording:
buffer += byte
if byte == b"]":
recording = False
if b"FileAlloc" not in buffer:
try:
# id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs
# eta may not always be available
parts = buffer.decode()[1:-1].split()
dl_parts = parts[1].split("(")
if len(dl_parts) == 2:
# might otherwise be e.g., 0B/0B, with no % symbol provided
progress(
total=100,
completed=int(dl_parts[1][:-2]),
downloaded=f"{parts[3].split(':')[1]}/s"
)
except Exception as e:
print(f"Aria2c progress failed on {buffer}, {e!r}")
sys.exit(1)
buffer = b""
await p.wait()
if p.returncode != 0: if p.returncode != 0:
raise subprocess.CalledProcessError(p.returncode, arguments) raise subprocess.CalledProcessError(p.returncode, arguments)

View File

@ -11,6 +11,7 @@ import traceback
from concurrent import futures from concurrent import futures
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from copy import copy from copy import copy
from functools import partial
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from threading import Event from threading import Event
@ -23,7 +24,7 @@ from langcodes import Language, tag_is_valid
from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH from pywidevine.pssh import PSSH
from requests import Session from requests import Session
from tqdm import tqdm from rich import filesize
from devine.core.console import console from devine.core.console import console
from devine.core.constants import AnyTrack from devine.core.constants import AnyTrack
@ -274,6 +275,7 @@ class DASH:
def download_track( def download_track(
track: AnyTrack, track: AnyTrack,
save_dir: Path, save_dir: Path,
progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None license_widevine: Optional[Callable] = None
@ -447,10 +449,10 @@ class DASH:
state_event = Event() state_event = Event()
def download_segment(filename: str, segment: tuple[str, Optional[str]]): def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
time.sleep(0.1) time.sleep(0.1)
if state_event.is_set(): if state_event.is_set():
return return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -476,6 +478,8 @@ class DASH:
silent=True silent=True
)) ))
data_size = len(init_data or b"")
if isinstance(track, Audio) or init_data: if isinstance(track, Audio) or init_data:
with open(segment_save_path, "rb+") as f: with open(segment_save_path, "rb+") as f:
segment_data = f.read() segment_data = f.read()
@ -492,6 +496,7 @@ class DASH:
f.seek(0) f.seek(0)
f.write(init_data) f.write(init_data)
f.write(segment_data) f.write(segment_data)
data_size += len(segment_data)
if drm: if drm:
# TODO: What if the manifest does not mention DRM, but has DRM # TODO: What if the manifest does not mention DRM, but has DRM
@ -500,33 +505,48 @@ class DASH:
if callable(track.OnDecrypted): if callable(track.OnDecrypted):
track.OnDecrypted(track) track.OnDecrypted(track)
with tqdm(total=len(segments), unit="segments") as pbar: return data_size
with ThreadPoolExecutor(max_workers=16) as pool:
try: progress(total=len(segments))
for download in futures.as_completed((
pool.submit( download_start_time = time.time()
download_segment, download_sizes = []
filename=str(i).zfill(len(str(len(segments)))),
segment=segment with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
)
for i, segment in enumerate(segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
download_size = download.result()
elapsed_time = time.time() - download_start_time
download_sizes.append(download_size)
while elapsed_time - len(download_sizes) > 10:
download_sizes.pop(0)
download_speed = sum(download_sizes) / len(download_sizes)
progress(
advance=1,
downloaded=f"DASH {filesize.decimal(download_speed)}/s"
) )
for i, segment in enumerate(segments) except KeyboardInterrupt:
)): state_event.set()
if download.cancelled(): pool.shutdown(wait=False, cancel_futures=True)
continue console.log("Received Keyboard Interrupt, stopping...")
e = download.exception() return
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
@staticmethod @staticmethod
def get_language(*options: Any) -> Optional[Language]: def get_language(*options: Any) -> Optional[Language]:

View File

@ -8,6 +8,7 @@ import time
import traceback import traceback
from concurrent import futures from concurrent import futures
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from functools import partial
from hashlib import md5 from hashlib import md5
from pathlib import Path from pathlib import Path
from queue import Queue from queue import Queue
@ -21,7 +22,7 @@ from m3u8 import M3U8
from pywidevine.cdm import Cdm as WidevineCdm from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH from pywidevine.pssh import PSSH
from requests import Session from requests import Session
from tqdm import tqdm from rich import filesize
from devine.core.console import console from devine.core.console import console
from devine.core.constants import AnyTrack from devine.core.constants import AnyTrack
@ -183,6 +184,7 @@ class HLS:
def download_track( def download_track(
track: AnyTrack, track: AnyTrack,
save_dir: Path, save_dir: Path,
progress: partial,
session: Optional[Session] = None, session: Optional[Session] = None,
proxy: Optional[str] = None, proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None license_widevine: Optional[Callable] = None
@ -214,16 +216,10 @@ class HLS:
state_event = Event() state_event = Event()
def download_segment( def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue) -> int:
filename: str,
segment: m3u8.Segment,
init_data: Queue,
segment_key: Queue,
range_offset: Queue
) -> None:
time.sleep(0.1) time.sleep(0.1)
if state_event.is_set(): if state_event.is_set():
return return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4") segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -255,7 +251,7 @@ class HLS:
segment_key.put(newest_segment_key) segment_key.put(newest_segment_key)
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment): if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return return 0
newest_init_data = init_data.get() newest_init_data = init_data.get()
if segment.init_section and (not newest_init_data or segment.discontinuity): if segment.init_section and (not newest_init_data or segment.discontinuity):
@ -302,6 +298,8 @@ class HLS:
silent=True silent=True
)) ))
data_size = len(newest_init_data or b"")
if isinstance(track, Audio) or newest_init_data: if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f: with open(segment_save_path, "rb+") as f:
segment_data = f.read() segment_data = f.read()
@ -318,6 +316,7 @@ class HLS:
f.seek(0) f.seek(0)
f.write(newest_init_data) f.write(newest_init_data)
f.write(segment_data) f.write(segment_data)
data_size += len(segment_data)
if newest_segment_key[0]: if newest_segment_key[0]:
newest_segment_key[0].decrypt(segment_save_path) newest_segment_key[0].decrypt(segment_save_path)
@ -325,6 +324,8 @@ class HLS:
if callable(track.OnDecrypted): if callable(track.OnDecrypted):
track.OnDecrypted(track) track.OnDecrypted(track)
return data_size
segment_key = Queue(maxsize=1) segment_key = Queue(maxsize=1)
init_data = Queue(maxsize=1) init_data = Queue(maxsize=1)
range_offset = Queue(maxsize=1) range_offset = Queue(maxsize=1)
@ -344,36 +345,48 @@ class HLS:
init_data.put(None) init_data.put(None)
range_offset.put(0) range_offset.put(0)
with tqdm(total=len(master.segments), unit="segments") as pbar: progress(total=len(master.segments))
with ThreadPoolExecutor(max_workers=16) as pool:
try: download_start_time = time.time()
for download in futures.as_completed(( download_sizes = []
pool.submit(
download_segment, with ThreadPoolExecutor(max_workers=16) as pool:
filename=str(i).zfill(len(str(len(master.segments)))), try:
segment=segment, for download in futures.as_completed((
init_data=init_data, pool.submit(
segment_key=segment_key, download_segment,
range_offset=range_offset filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
download_size = download.result()
elapsed_time = time.time() - download_start_time
download_sizes.append(download_size)
while elapsed_time - len(download_sizes) > 10:
download_sizes.pop(0)
download_speed = sum(download_sizes) / len(download_sizes)
progress(
advance=1,
downloaded=f"HLS {filesize.decimal(download_speed)}/s"
) )
for i, segment in enumerate(master.segments) except KeyboardInterrupt:
)): state_event.set()
if download.cancelled(): pool.shutdown(wait=False, cancel_futures=True)
continue console.log("Received Keyboard Interrupt, stopping...")
e = download.exception() return
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
@staticmethod @staticmethod
def get_drm( def get_drm(

View File

@ -2,14 +2,18 @@ from __future__ import annotations
import logging import logging
import subprocess import subprocess
from functools import partial
from pathlib import Path from pathlib import Path
from typing import Callable, Iterator, Optional, Sequence, Union from typing import Callable, Iterator, Optional, Sequence, Union
from Cryptodome.Random import get_random_bytes from Cryptodome.Random import get_random_bytes
from langcodes import Language, closest_supported_match from langcodes import Language, closest_supported_match
from rich.progress import Progress, TextColumn, SpinnerColumn, BarColumn, TimeRemainingColumn
from rich.table import Table
from rich.tree import Tree from rich.tree import Tree
from devine.core.config import config from devine.core.config import config
from devine.core.console import console
from devine.core.constants import LANGUAGE_MAX_DISTANCE, LANGUAGE_MUX_MAP, AnyTrack, TrackT from devine.core.constants import LANGUAGE_MAX_DISTANCE, LANGUAGE_MUX_MAP, AnyTrack, TrackT
from devine.core.tracks.audio import Audio from devine.core.tracks.audio import Audio
from devine.core.tracks.chapter import Chapter from devine.core.tracks.chapter import Chapter
@ -87,9 +91,11 @@ class Tracks:
return rep return rep
def tree(self) -> Tree: def tree(self, add_progress: bool = False) -> tuple[Tree, list[partial]]:
all_tracks = [*list(self), *self.chapters] all_tracks = [*list(self), *self.chapters]
progress_callables = []
tree = Tree("", hide_root=True) tree = Tree("", hide_root=True)
for track_type in self.TRACK_ORDER_MAP: for track_type in self.TRACK_ORDER_MAP:
tracks = list(x for x in all_tracks if isinstance(x, track_type)) tracks = list(x for x in all_tracks if isinstance(x, track_type))
@ -99,9 +105,28 @@ class Tracks:
track_type_plural = track_type.__name__ + ("s" if track_type != Audio and num_tracks != 1 else "") track_type_plural = track_type.__name__ + ("s" if track_type != Audio and num_tracks != 1 else "")
tracks_tree = tree.add(f"[repr.number]{num_tracks}[/] {track_type_plural}") tracks_tree = tree.add(f"[repr.number]{num_tracks}[/] {track_type_plural}")
for track in tracks: for track in tracks:
tracks_tree.add(str(track)[6:], style="text2") if add_progress and track_type != Chapter:
progress = Progress(
TextColumn("[progress.description]{task.description}"),
SpinnerColumn(),
BarColumn(),
"",
TimeRemainingColumn(compact=True, elapsed_when_finished=True),
"",
TextColumn("[progress.data.speed]{task.fields[downloaded]}"),
console=console,
speed_estimate_period=10
)
task = progress.add_task("", downloaded="-")
progress_callables.append(partial(progress.update, task_id=task))
track_table = Table.grid()
track_table.add_row(str(track)[6:], style="text2")
track_table.add_row(progress)
tracks_tree.add(track_table)
else:
tracks_tree.add(str(track)[6:], style="text2")
return tree return tree, progress_callables
def exists(self, by_id: Optional[str] = None, by_url: Optional[Union[str, list[str]]] = None) -> bool: def exists(self, by_id: Optional[str] = None, by_url: Optional[Union[str, list[str]]] = None) -> bool:
"""Check if a track already exists by various methods.""" """Check if a track already exists by various methods."""