Replace tqdm progress bars with rich progress bars

This commit is contained in:
rlaphoenix 2023-02-25 13:45:17 +00:00
parent cc69423374
commit 92895426b3
5 changed files with 202 additions and 87 deletions

View File

@ -27,7 +27,7 @@ from pymediainfo import MediaInfo
from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.device import Device
from pywidevine.remotecdm import RemoteCdm
from tqdm import tqdm
from rich.live import Live
from rich.padding import Padding
from rich.rule import Rule
@ -332,7 +332,7 @@ class dl:
title.tracks.sort_chapters()
if list_:
available_tracks = title.tracks.tree()
available_tracks, _ = title.tracks.tree()
console.log(available_tracks)
continue
@ -413,13 +413,19 @@ class dl:
if not subs_only:
title.tracks.subtitles.clear()
selected_tracks = title.tracks.tree()
console.log(selected_tracks)
selected_tracks, tracks_progress_callables = title.tracks.tree(add_progress=True)
if skip_dl:
console.log("Skipping Download...")
else:
with tqdm(total=len(title.tracks)) as pbar:
with Live(
Padding(
selected_tracks,
(0, 5, 1, 5)
),
console=console,
refresh_per_second=5
):
with ThreadPoolExecutor(workers) as pool:
try:
for download in futures.as_completed((
@ -445,9 +451,10 @@ class dl:
cdm_only=cdm_only,
vaults_only=vaults_only,
export=export
)
),
progress=tracks_progress_callables[i]
)
for track in title.tracks
for i, track in enumerate(title.tracks)
)):
if download.cancelled():
continue
@ -458,8 +465,6 @@ class dl:
traceback.print_exception(type(e), e, e.__traceback__)
self.log.error(f"Download worker threw an unhandled exception: {e!r}")
return
else:
pbar.update(1)
except KeyboardInterrupt:
self.DL_POOL_STOP.set()
pool.shutdown(wait=False, cancel_futures=True)
@ -570,7 +575,8 @@ class dl:
service: Service,
track: AnyTrack,
title: Title_T,
prepare_drm: Callable
prepare_drm: Callable,
progress: partial
):
time.sleep(1)
if self.DL_POOL_STOP.is_set():
@ -581,8 +587,6 @@ class dl:
else:
proxy = None
console.log(f"Downloading: {track}")
if config.directories.temp.is_file():
self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file")
sys.exit(1)
@ -612,6 +616,7 @@ class dl:
HLS.download_track(
track=track,
save_dir=save_dir,
progress=progress,
session=service.session,
proxy=proxy,
license_widevine=prepare_drm
@ -620,6 +625,7 @@ class dl:
DASH.download_track(
track=track,
save_dir=save_dir,
progress=progress,
session=service.session,
proxy=proxy,
license_widevine=prepare_drm
@ -631,7 +637,8 @@ class dl:
track.url,
save_path,
service.session.headers,
proxy if track.needs_proxy else None
proxy if track.needs_proxy else None,
progress=progress
))
track.path = save_path

View File

@ -1,5 +1,8 @@
import asyncio
import subprocess
import sys
from asyncio import IncompleteReadError
from functools import partial
from pathlib import Path
from typing import Optional, Union
@ -13,6 +16,7 @@ async def aria2c(
headers: Optional[dict] = None,
proxy: Optional[str] = None,
silent: bool = False,
progress: Optional[partial] = None,
*args: str
) -> int:
"""
@ -59,7 +63,7 @@ async def aria2c(
"--summary-interval", "0",
"--file-allocation", config.aria2c.get("file_allocation", "falloc"),
"--console-log-level", "warn",
"--download-result", "hide",
"--download-result", ["hide", "default"][bool(progress)],
*args,
"-i", "-"
]
@ -84,9 +88,55 @@ async def aria2c(
*arguments,
stdin=subprocess.PIPE,
stderr=[None, subprocess.DEVNULL][silent],
stdout=[None, subprocess.DEVNULL][silent]
stdout=(
subprocess.PIPE if progress else
subprocess.DEVNULL if silent else
None
)
)
await p.communicate(uri.encode())
p.stdin.write(uri.encode())
await p.stdin.drain()
p.stdin.close()
if progress:
# I'm sorry for this shameful code, aria2(c) is annoying as f!!!
buffer = b""
recording = False
while not p.stdout.at_eof():
try:
byte = await p.stdout.readexactly(1)
except IncompleteReadError:
pass # ignore, the first read will do this
else:
if byte == b"=": # download result log
progress(total=100, completed=100)
break
if byte == b"[":
recording = True
if recording:
buffer += byte
if byte == b"]":
recording = False
if b"FileAlloc" not in buffer:
try:
# id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs
# eta may not always be available
parts = buffer.decode()[1:-1].split()
dl_parts = parts[1].split("(")
if len(dl_parts) == 2:
# might otherwise be e.g., 0B/0B, with no % symbol provided
progress(
total=100,
completed=int(dl_parts[1][:-2]),
downloaded=f"{parts[3].split(':')[1]}/s"
)
except Exception as e:
print(f"Aria2c progress failed on {buffer}, {e!r}")
sys.exit(1)
buffer = b""
await p.wait()
if p.returncode != 0:
raise subprocess.CalledProcessError(p.returncode, arguments)

View File

@ -11,6 +11,7 @@ import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from copy import copy
from functools import partial
from hashlib import md5
from pathlib import Path
from threading import Event
@ -23,7 +24,7 @@ from langcodes import Language, tag_is_valid
from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH
from requests import Session
from tqdm import tqdm
from rich import filesize
from devine.core.console import console
from devine.core.constants import AnyTrack
@ -274,6 +275,7 @@ class DASH:
def download_track(
track: AnyTrack,
save_dir: Path,
progress: partial,
session: Optional[Session] = None,
proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None
@ -447,10 +449,10 @@ class DASH:
state_event = Event()
def download_segment(filename: str, segment: tuple[str, Optional[str]]):
def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
time.sleep(0.1)
if state_event.is_set():
return
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -476,6 +478,8 @@ class DASH:
silent=True
))
data_size = len(init_data or b"")
if isinstance(track, Audio) or init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
@ -492,6 +496,7 @@ class DASH:
f.seek(0)
f.write(init_data)
f.write(segment_data)
data_size += len(segment_data)
if drm:
# TODO: What if the manifest does not mention DRM, but has DRM
@ -500,33 +505,48 @@ class DASH:
if callable(track.OnDecrypted):
track.OnDecrypted(track)
with tqdm(total=len(segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
return data_size
progress(total=len(segments))
download_start_time = time.time()
download_sizes = []
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(segments)))),
segment=segment
)
for i, segment in enumerate(segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
download_size = download.result()
elapsed_time = time.time() - download_start_time
download_sizes.append(download_size)
while elapsed_time - len(download_sizes) > 10:
download_sizes.pop(0)
download_speed = sum(download_sizes) / len(download_sizes)
progress(
advance=1,
downloaded=f"DASH {filesize.decimal(download_speed)}/s"
)
for i, segment in enumerate(segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
@staticmethod
def get_language(*options: Any) -> Optional[Language]:

View File

@ -8,6 +8,7 @@ import time
import traceback
from concurrent import futures
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from hashlib import md5
from pathlib import Path
from queue import Queue
@ -21,7 +22,7 @@ from m3u8 import M3U8
from pywidevine.cdm import Cdm as WidevineCdm
from pywidevine.pssh import PSSH
from requests import Session
from tqdm import tqdm
from rich import filesize
from devine.core.console import console
from devine.core.constants import AnyTrack
@ -183,6 +184,7 @@ class HLS:
def download_track(
track: AnyTrack,
save_dir: Path,
progress: partial,
session: Optional[Session] = None,
proxy: Optional[str] = None,
license_widevine: Optional[Callable] = None
@ -214,16 +216,10 @@ class HLS:
state_event = Event()
def download_segment(
filename: str,
segment: m3u8.Segment,
init_data: Queue,
segment_key: Queue,
range_offset: Queue
) -> None:
def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue) -> int:
time.sleep(0.1)
if state_event.is_set():
return
return 0
segment_save_path = (save_dir / filename).with_suffix(".mp4")
@ -255,7 +251,7 @@ class HLS:
segment_key.put(newest_segment_key)
if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
return
return 0
newest_init_data = init_data.get()
if segment.init_section and (not newest_init_data or segment.discontinuity):
@ -302,6 +298,8 @@ class HLS:
silent=True
))
data_size = len(newest_init_data or b"")
if isinstance(track, Audio) or newest_init_data:
with open(segment_save_path, "rb+") as f:
segment_data = f.read()
@ -318,6 +316,7 @@ class HLS:
f.seek(0)
f.write(newest_init_data)
f.write(segment_data)
data_size += len(segment_data)
if newest_segment_key[0]:
newest_segment_key[0].decrypt(segment_save_path)
@ -325,6 +324,8 @@ class HLS:
if callable(track.OnDecrypted):
track.OnDecrypted(track)
return data_size
segment_key = Queue(maxsize=1)
init_data = Queue(maxsize=1)
range_offset = Queue(maxsize=1)
@ -344,36 +345,48 @@ class HLS:
init_data.put(None)
range_offset.put(0)
with tqdm(total=len(master.segments), unit="segments") as pbar:
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key,
range_offset=range_offset
progress(total=len(master.segments))
download_start_time = time.time()
download_sizes = []
with ThreadPoolExecutor(max_workers=16) as pool:
try:
for download in futures.as_completed((
pool.submit(
download_segment,
filename=str(i).zfill(len(str(len(master.segments)))),
segment=segment,
init_data=init_data,
segment_key=segment_key
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
download_size = download.result()
elapsed_time = time.time() - download_start_time
download_sizes.append(download_size)
while elapsed_time - len(download_sizes) > 10:
download_sizes.pop(0)
download_speed = sum(download_sizes) / len(download_sizes)
progress(
advance=1,
downloaded=f"HLS {filesize.decimal(download_speed)}/s"
)
for i, segment in enumerate(master.segments)
)):
if download.cancelled():
continue
e = download.exception()
if e:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
traceback.print_exception(e)
log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
sys.exit(1)
else:
pbar.update(1)
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
except KeyboardInterrupt:
state_event.set()
pool.shutdown(wait=False, cancel_futures=True)
console.log("Received Keyboard Interrupt, stopping...")
return
@staticmethod
def get_drm(

View File

@ -2,14 +2,18 @@ from __future__ import annotations
import logging
import subprocess
from functools import partial
from pathlib import Path
from typing import Callable, Iterator, Optional, Sequence, Union
from Cryptodome.Random import get_random_bytes
from langcodes import Language, closest_supported_match
from rich.progress import Progress, TextColumn, SpinnerColumn, BarColumn, TimeRemainingColumn
from rich.table import Table
from rich.tree import Tree
from devine.core.config import config
from devine.core.console import console
from devine.core.constants import LANGUAGE_MAX_DISTANCE, LANGUAGE_MUX_MAP, AnyTrack, TrackT
from devine.core.tracks.audio import Audio
from devine.core.tracks.chapter import Chapter
@ -87,9 +91,11 @@ class Tracks:
return rep
def tree(self) -> Tree:
def tree(self, add_progress: bool = False) -> tuple[Tree, list[partial]]:
all_tracks = [*list(self), *self.chapters]
progress_callables = []
tree = Tree("", hide_root=True)
for track_type in self.TRACK_ORDER_MAP:
tracks = list(x for x in all_tracks if isinstance(x, track_type))
@ -99,9 +105,28 @@ class Tracks:
track_type_plural = track_type.__name__ + ("s" if track_type != Audio and num_tracks != 1 else "")
tracks_tree = tree.add(f"[repr.number]{num_tracks}[/] {track_type_plural}")
for track in tracks:
tracks_tree.add(str(track)[6:], style="text2")
if add_progress and track_type != Chapter:
progress = Progress(
TextColumn("[progress.description]{task.description}"),
SpinnerColumn(),
BarColumn(),
"",
TimeRemainingColumn(compact=True, elapsed_when_finished=True),
"",
TextColumn("[progress.data.speed]{task.fields[downloaded]}"),
console=console,
speed_estimate_period=10
)
task = progress.add_task("", downloaded="-")
progress_callables.append(partial(progress.update, task_id=task))
track_table = Table.grid()
track_table.add_row(str(track)[6:], style="text2")
track_table.add_row(progress)
tracks_tree.add(track_table)
else:
tracks_tree.add(str(track)[6:], style="text2")
return tree
return tree, progress_callables
def exists(self, by_id: Optional[str] = None, by_url: Optional[Union[str, list[str]]] = None) -> bool:
"""Check if a track already exists by various methods."""