forked from DRMTalks/devine
Replace tqdm progress bars with rich progress bars
This commit is contained in:
  parent cc69423374
  commit 92895426b3
@@ -27,7 +27,7 @@ from pymediainfo import MediaInfo
 from pywidevine.cdm import Cdm as WidevineCdm
 from pywidevine.device import Device
 from pywidevine.remotecdm import RemoteCdm
-from tqdm import tqdm
+from rich.live import Live
+from rich.padding import Padding
+from rich.rule import Rule

@@ -332,7 +332,7 @@ class dl:
             title.tracks.sort_chapters()

             if list_:
-                available_tracks = title.tracks.tree()
+                available_tracks, _ = title.tracks.tree()
                 console.log(available_tracks)
                 continue

@@ -413,13 +413,19 @@ class dl:
             if not subs_only:
                 title.tracks.subtitles.clear()

-            selected_tracks = title.tracks.tree()
-            console.log(selected_tracks)
+            selected_tracks, tracks_progress_callables = title.tracks.tree(add_progress=True)

             if skip_dl:
                 console.log("Skipping Download...")
             else:
-                with tqdm(total=len(title.tracks)) as pbar:
+                with Live(
+                    Padding(
+                        selected_tracks,
+                        (0, 5, 1, 5)
+                    ),
+                    console=console,
+                    refresh_per_second=5
+                ):
                     with ThreadPoolExecutor(workers) as pool:
                         try:
                             for download in futures.as_completed((

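Note: the Live display above keeps the selected-tracks tree (which now embeds per-track progress bars) re-rendering while the worker pool downloads. A minimal sketch of that pattern outside devine, with a throwaway sleep standing in for real download work:

# Sketch of the rich Live + renderable pattern used above (not devine's actual code).
import time
from rich.console import Console
from rich.live import Live
from rich.padding import Padding
from rich.progress import Progress

console = Console()
progress = Progress(console=console)
task = progress.add_task("Downloading", total=10)

# Live re-renders the wrapped renderable a few times per second; any update to
# the Progress object (from this thread or from workers) shows up automatically.
with Live(Padding(progress, (0, 5, 1, 5)), console=console, refresh_per_second=5):
    for _ in range(10):
        time.sleep(0.2)  # stand-in for real download work
        progress.update(task, advance=1)

Anything that mutates the wrapped renderable shows up on the next refresh; no explicit redraw calls are needed.
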
@@ -445,9 +451,10 @@ class dl:
                                         cdm_only=cdm_only,
                                         vaults_only=vaults_only,
                                         export=export
-                                    )
+                                    ),
+                                    progress=tracks_progress_callables[i]
                                 )
-                                for track in title.tracks
+                                for i, track in enumerate(title.tracks)
                             )):
                                 if download.cancelled():
                                     continue

@@ -458,8 +465,6 @@ class dl:
                                     traceback.print_exception(type(e), e, e.__traceback__)
                                     self.log.error(f"Download worker threw an unhandled exception: {e!r}")
                                     return
-                                else:
-                                    pbar.update(1)
                         except KeyboardInterrupt:
                             self.DL_POOL_STOP.set()
                             pool.shutdown(wait=False, cancel_futures=True)

@@ -570,7 +575,8 @@ class dl:
         service: Service,
         track: AnyTrack,
         title: Title_T,
-        prepare_drm: Callable
+        prepare_drm: Callable,
+        progress: partial
     ):
         time.sleep(1)
         if self.DL_POOL_STOP.is_set():

@@ -581,8 +587,6 @@ class dl:
         else:
             proxy = None

-        console.log(f"Downloading: {track}")
-
         if config.directories.temp.is_file():
             self.log.error(f"Temp Directory '{config.directories.temp}' must be a Directory, not a file")
             sys.exit(1)

@@ -612,6 +616,7 @@ class dl:
             HLS.download_track(
                 track=track,
                 save_dir=save_dir,
+                progress=progress,
                 session=service.session,
                 proxy=proxy,
                 license_widevine=prepare_drm

@@ -620,6 +625,7 @@ class dl:
             DASH.download_track(
                 track=track,
                 save_dir=save_dir,
+                progress=progress,
                 session=service.session,
                 proxy=proxy,
                 license_widevine=prepare_drm

@@ -631,7 +637,8 @@ class dl:
                 track.url,
                 save_path,
                 service.session.headers,
-                proxy if track.needs_proxy else None
+                proxy if track.needs_proxy else None,
+                progress=progress
             ))
             track.path = save_path

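With this change each downloader receives a ready-made progress callable rather than owning its own tqdm bar: the dl command binds Progress.update to a specific task with functools.partial, and the downloader only reports totals, advances, and a "downloaded" speed string. A rough sketch of that hand-off (the fake_download name is invented, not devine's):

# Sketch (assumed names) of handing a pre-bound progress callable to a download worker.
from functools import partial
from rich.progress import Progress

def fake_download(progress):  # hypothetical stand-in for HLS/DASH/aria2c download code
    progress(total=3)
    for _ in range(3):
        progress(advance=1, downloaded="1.0 MB/s")

with Progress() as progress_bars:
    task_id = progress_bars.add_task("track 1", downloaded="-")
    # the worker never sees the Progress object or the task id directly
    fake_download(partial(progress_bars.update, task_id=task_id))
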

@@ -1,5 +1,8 @@
 import asyncio
 import subprocess
+import sys
+from asyncio import IncompleteReadError
+from functools import partial
 from pathlib import Path
 from typing import Optional, Union

@@ -13,6 +16,7 @@ async def aria2c(
     headers: Optional[dict] = None,
     proxy: Optional[str] = None,
     silent: bool = False,
+    progress: Optional[partial] = None,
     *args: str
 ) -> int:
     """

@@ -59,7 +63,7 @@ async def aria2c(
         "--summary-interval", "0",
         "--file-allocation", config.aria2c.get("file_allocation", "falloc"),
         "--console-log-level", "warn",
-        "--download-result", "hide",
+        "--download-result", ["hide", "default"][bool(progress)],
         *args,
         "-i", "-"
     ]

@@ -84,9 +88,55 @@ async def aria2c(
         *arguments,
         stdin=subprocess.PIPE,
         stderr=[None, subprocess.DEVNULL][silent],
-        stdout=[None, subprocess.DEVNULL][silent]
+        stdout=(
+            subprocess.PIPE if progress else
+            subprocess.DEVNULL if silent else
+            None
+        )
     )
-    await p.communicate(uri.encode())
+
+    p.stdin.write(uri.encode())
+    await p.stdin.drain()
+    p.stdin.close()
+
+    if progress:
+        # I'm sorry for this shameful code, aria2(c) is annoying as f!!!
+        buffer = b""
+        recording = False
+        while not p.stdout.at_eof():
+            try:
+                byte = await p.stdout.readexactly(1)
+            except IncompleteReadError:
+                pass  # ignore, the first read will do this
+            else:
+                if byte == b"=":  # download result log
+                    progress(total=100, completed=100)
+                    break
+                if byte == b"[":
+                    recording = True
+                if recording:
+                    buffer += byte
+                if byte == b"]":
+                    recording = False
+                    if b"FileAlloc" not in buffer:
+                        try:
+                            # id, dledMiB/totalMiB(x%), CN:xx, DL:xxMiB, ETA:Xs
+                            # eta may not always be available
+                            parts = buffer.decode()[1:-1].split()
+                            dl_parts = parts[1].split("(")
+                            if len(dl_parts) == 2:
+                                # might otherwise be e.g., 0B/0B, with no % symbol provided
+                                progress(
+                                    total=100,
+                                    completed=int(dl_parts[1][:-2]),
+                                    downloaded=f"{parts[3].split(':')[1]}/s"
+                                )
+                        except Exception as e:
+                            print(f"Aria2c progress failed on {buffer}, {e!r}")
+                            sys.exit(1)
+                    buffer = b""
+
+    await p.wait()

     if p.returncode != 0:
         raise subprocess.CalledProcessError(p.returncode, arguments)

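The byte-by-byte loop above scoops up aria2c's bracketed console readout ("[#id dledMiB/totalMiB(x%) CN:xx DL:xxMiB ETA:Xs]", per the comment) and turns it into percent-complete and speed updates. The same parsing applied to a single captured chunk, with a made-up sample line:

# Sketch of parsing one aria2c console-readout chunk; the sample string is invented.
readout = "[#7f2a3c 12MiB/100MiB(12%) CN:8 DL:3.4MiB ETA:26s]"

parts = readout[1:-1].split()              # drop the surrounding [ ] and split on spaces
dl_parts = parts[1].split("(")             # -> ["12MiB/100MiB", "12%)"]
if len(dl_parts) == 2:                     # a percentage only appears once sizes are known
    completed = int(dl_parts[1][:-2])      # strip "%)" -> 12
    speed = f"{parts[3].split(':')[1]}/s"  # "DL:3.4MiB" -> "3.4MiB/s"
    print(completed, speed)
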

@@ -11,6 +11,7 @@ import traceback
 from concurrent import futures
 from concurrent.futures import ThreadPoolExecutor
 from copy import copy
+from functools import partial
 from hashlib import md5
 from pathlib import Path
 from threading import Event

@@ -23,7 +24,7 @@ from langcodes import Language, tag_is_valid
 from pywidevine.cdm import Cdm as WidevineCdm
 from pywidevine.pssh import PSSH
 from requests import Session
-from tqdm import tqdm
+from rich import filesize

 from devine.core.console import console
 from devine.core.constants import AnyTrack

@@ -274,6 +275,7 @@ class DASH:
     def download_track(
         track: AnyTrack,
         save_dir: Path,
+        progress: partial,
         session: Optional[Session] = None,
         proxy: Optional[str] = None,
         license_widevine: Optional[Callable] = None

@@ -447,10 +449,10 @@ class DASH:

         state_event = Event()

-        def download_segment(filename: str, segment: tuple[str, Optional[str]]):
+        def download_segment(filename: str, segment: tuple[str, Optional[str]]) -> int:
             time.sleep(0.1)
             if state_event.is_set():
-                return
+                return 0

             segment_save_path = (save_dir / filename).with_suffix(".mp4")

@@ -476,6 +478,8 @@ class DASH:
                     silent=True
                 ))

+            data_size = len(init_data or b"")
+
             if isinstance(track, Audio) or init_data:
                 with open(segment_save_path, "rb+") as f:
                     segment_data = f.read()

@@ -492,6 +496,7 @@ class DASH:
                     f.seek(0)
                     f.write(init_data)
                     f.write(segment_data)
+                    data_size += len(segment_data)

             if drm:
                 # TODO: What if the manifest does not mention DRM, but has DRM

@@ -500,33 +505,48 @@ class DASH:
                 if callable(track.OnDecrypted):
                     track.OnDecrypted(track)

-        with tqdm(total=len(segments), unit="segments") as pbar:
-            with ThreadPoolExecutor(max_workers=16) as pool:
-                try:
-                    for download in futures.as_completed((
-                        pool.submit(
-                            download_segment,
-                            filename=str(i).zfill(len(str(len(segments)))),
-                            segment=segment
-                        )
-                        for i, segment in enumerate(segments)
-                    )):
-                        if download.cancelled():
-                            continue
-                        e = download.exception()
-                        if e:
-                            state_event.set()
-                            pool.shutdown(wait=False, cancel_futures=True)
-                            traceback.print_exception(e)
-                            log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
-                            sys.exit(1)
-                        else:
-                            pbar.update(1)
-                except KeyboardInterrupt:
-                    state_event.set()
-                    pool.shutdown(wait=False, cancel_futures=True)
-                    console.log("Received Keyboard Interrupt, stopping...")
-                    return
+            return data_size
+
+        progress(total=len(segments))
+
+        download_start_time = time.time()
+        download_sizes = []
+
+        with ThreadPoolExecutor(max_workers=16) as pool:
+            try:
+                for download in futures.as_completed((
+                    pool.submit(
+                        download_segment,
+                        filename=str(i).zfill(len(str(len(segments)))),
+                        segment=segment
+                    )
+                    for i, segment in enumerate(segments)
+                )):
+                    if download.cancelled():
+                        continue
+                    e = download.exception()
+                    if e:
+                        state_event.set()
+                        pool.shutdown(wait=False, cancel_futures=True)
+                        traceback.print_exception(e)
+                        log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
+                        sys.exit(1)
+                    else:
+                        download_size = download.result()
+                        elapsed_time = time.time() - download_start_time
+                        download_sizes.append(download_size)
+                        while elapsed_time - len(download_sizes) > 10:
+                            download_sizes.pop(0)
+                        download_speed = sum(download_sizes) / len(download_sizes)
+                        progress(
+                            advance=1,
+                            downloaded=f"DASH {filesize.decimal(download_speed)}/s"
+                        )
+            except KeyboardInterrupt:
+                state_event.set()
+                pool.shutdown(wait=False, cancel_futures=True)
+                console.log("Received Keyboard Interrupt, stopping...")
+                return

     @staticmethod
     def get_language(*options: Any) -> Optional[Language]:

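Both the DASH and HLS segment loops estimate speed with a rolling window: each finished segment contributes its byte size, old samples are popped while the elapsed seconds exceed the sample count by more than ten, and the average of what remains is reported through the "downloaded" field. A self-contained sketch of that bookkeeping with fabricated segment sizes:

# Sketch of the rolling-window speed estimate used by the segment loops; sizes are made up.
import time
from rich import filesize

download_start_time = time.time()
download_sizes = []

for segment_size in [1_500_000, 2_000_000, 1_750_000]:  # bytes per finished segment
    elapsed_time = time.time() - download_start_time
    download_sizes.append(segment_size)
    # pop old samples while elapsed seconds exceed the sample count by more than 10
    while elapsed_time - len(download_sizes) > 10:
        download_sizes.pop(0)
    download_speed = sum(download_sizes) / len(download_sizes)
    print(f"{filesize.decimal(download_speed)}/s")
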

@@ -8,6 +8,7 @@ import time
 import traceback
 from concurrent import futures
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from hashlib import md5
 from pathlib import Path
 from queue import Queue

@@ -21,7 +22,7 @@ from m3u8 import M3U8
 from pywidevine.cdm import Cdm as WidevineCdm
 from pywidevine.pssh import PSSH
 from requests import Session
-from tqdm import tqdm
+from rich import filesize

 from devine.core.console import console
 from devine.core.constants import AnyTrack

@@ -183,6 +184,7 @@ class HLS:
     def download_track(
         track: AnyTrack,
         save_dir: Path,
+        progress: partial,
         session: Optional[Session] = None,
         proxy: Optional[str] = None,
         license_widevine: Optional[Callable] = None

@@ -214,16 +216,10 @@ class HLS:

         state_event = Event()

-        def download_segment(
-            filename: str,
-            segment: m3u8.Segment,
-            init_data: Queue,
-            segment_key: Queue,
-            range_offset: Queue
-        ) -> None:
+        def download_segment(filename: str, segment, init_data: Queue, segment_key: Queue) -> int:
             time.sleep(0.1)
             if state_event.is_set():
-                return
+                return 0

             segment_save_path = (save_dir / filename).with_suffix(".mp4")

@@ -255,7 +251,7 @@ class HLS:
                 segment_key.put(newest_segment_key)

             if callable(track.OnSegmentFilter) and track.OnSegmentFilter(segment):
-                return
+                return 0

             newest_init_data = init_data.get()
             if segment.init_section and (not newest_init_data or segment.discontinuity):

@@ -302,6 +298,8 @@ class HLS:
                     silent=True
                 ))

+            data_size = len(newest_init_data or b"")
+
             if isinstance(track, Audio) or newest_init_data:
                 with open(segment_save_path, "rb+") as f:
                     segment_data = f.read()

@@ -318,6 +316,7 @@ class HLS:
                     f.seek(0)
                     f.write(newest_init_data)
                     f.write(segment_data)
+                    data_size += len(segment_data)

             if newest_segment_key[0]:
                 newest_segment_key[0].decrypt(segment_save_path)

@@ -325,6 +324,8 @@ class HLS:
                 if callable(track.OnDecrypted):
                     track.OnDecrypted(track)

+            return data_size
+
         segment_key = Queue(maxsize=1)
         init_data = Queue(maxsize=1)
         range_offset = Queue(maxsize=1)

@@ -344,36 +345,48 @@ class HLS:
         init_data.put(None)
         range_offset.put(0)

-        with tqdm(total=len(master.segments), unit="segments") as pbar:
-            with ThreadPoolExecutor(max_workers=16) as pool:
-                try:
-                    for download in futures.as_completed((
-                        pool.submit(
-                            download_segment,
-                            filename=str(i).zfill(len(str(len(master.segments)))),
-                            segment=segment,
-                            init_data=init_data,
-                            segment_key=segment_key,
-                            range_offset=range_offset
-                        )
-                        for i, segment in enumerate(master.segments)
-                    )):
-                        if download.cancelled():
-                            continue
-                        e = download.exception()
-                        if e:
-                            state_event.set()
-                            pool.shutdown(wait=False, cancel_futures=True)
-                            traceback.print_exception(e)
-                            log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
-                            sys.exit(1)
-                        else:
-                            pbar.update(1)
-                except KeyboardInterrupt:
-                    state_event.set()
-                    pool.shutdown(wait=False, cancel_futures=True)
-                    console.log("Received Keyboard Interrupt, stopping...")
-                    return
+        progress(total=len(master.segments))
+
+        download_start_time = time.time()
+        download_sizes = []
+
+        with ThreadPoolExecutor(max_workers=16) as pool:
+            try:
+                for download in futures.as_completed((
+                    pool.submit(
+                        download_segment,
+                        filename=str(i).zfill(len(str(len(master.segments)))),
+                        segment=segment,
+                        init_data=init_data,
+                        segment_key=segment_key
+                    )
+                    for i, segment in enumerate(master.segments)
+                )):
+                    if download.cancelled():
+                        continue
+                    e = download.exception()
+                    if e:
+                        state_event.set()
+                        pool.shutdown(wait=False, cancel_futures=True)
+                        traceback.print_exception(e)
+                        log.error(f"Segment Download worker threw an unhandled exception: {e!r}")
+                        sys.exit(1)
+                    else:
+                        download_size = download.result()
+                        elapsed_time = time.time() - download_start_time
+                        download_sizes.append(download_size)
+                        while elapsed_time - len(download_sizes) > 10:
+                            download_sizes.pop(0)
+                        download_speed = sum(download_sizes) / len(download_sizes)
+                        progress(
+                            advance=1,
+                            downloaded=f"HLS {filesize.decimal(download_speed)}/s"
+                        )
+            except KeyboardInterrupt:
+                state_event.set()
+                pool.shutdown(wait=False, cancel_futures=True)
+                console.log("Received Keyboard Interrupt, stopping...")
+                return

     @staticmethod
     def get_drm(

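The speed strings in both loops are formatted with rich.filesize.decimal, which uses SI (1000-based) units; a quick illustration:

from rich import filesize

# decimal() uses 1000-based (SI) units, e.g. kB/MB rather than KiB/MiB
print(filesize.decimal(3_400_000))             # "3.4 MB"
print(f"HLS {filesize.decimal(3_400_000)}/s")  # "HLS 3.4 MB/s"
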

@@ -2,14 +2,18 @@ from __future__ import annotations

 import logging
 import subprocess
+from functools import partial
 from pathlib import Path
 from typing import Callable, Iterator, Optional, Sequence, Union

 from Cryptodome.Random import get_random_bytes
 from langcodes import Language, closest_supported_match
+from rich.progress import Progress, TextColumn, SpinnerColumn, BarColumn, TimeRemainingColumn
+from rich.table import Table
 from rich.tree import Tree

 from devine.core.config import config
+from devine.core.console import console
 from devine.core.constants import LANGUAGE_MAX_DISTANCE, LANGUAGE_MUX_MAP, AnyTrack, TrackT
 from devine.core.tracks.audio import Audio
 from devine.core.tracks.chapter import Chapter

@@ -87,9 +91,11 @@ class Tracks:

         return rep

-    def tree(self) -> Tree:
+    def tree(self, add_progress: bool = False) -> tuple[Tree, list[partial]]:
         all_tracks = [*list(self), *self.chapters]

+        progress_callables = []
+
         tree = Tree("", hide_root=True)
         for track_type in self.TRACK_ORDER_MAP:
             tracks = list(x for x in all_tracks if isinstance(x, track_type))

@@ -99,9 +105,28 @@ class Tracks:
             track_type_plural = track_type.__name__ + ("s" if track_type != Audio and num_tracks != 1 else "")
             tracks_tree = tree.add(f"[repr.number]{num_tracks}[/] {track_type_plural}")
             for track in tracks:
-                tracks_tree.add(str(track)[6:], style="text2")
+                if add_progress and track_type != Chapter:
+                    progress = Progress(
+                        TextColumn("[progress.description]{task.description}"),
+                        SpinnerColumn(),
+                        BarColumn(),
+                        "•",
+                        TimeRemainingColumn(compact=True, elapsed_when_finished=True),
+                        "•",
+                        TextColumn("[progress.data.speed]{task.fields[downloaded]}"),
+                        console=console,
+                        speed_estimate_period=10
+                    )
+                    task = progress.add_task("", downloaded="-")
+                    progress_callables.append(partial(progress.update, task_id=task))
+                    track_table = Table.grid()
+                    track_table.add_row(str(track)[6:], style="text2")
+                    track_table.add_row(progress)
+                    tracks_tree.add(track_table)
+                else:
+                    tracks_tree.add(str(track)[6:], style="text2")

-        return tree
+        return tree, progress_callables

     def exists(self, by_id: Optional[str] = None, by_url: Optional[Union[str, list[str]]] = None) -> bool:
         """Check if a track already exists by various methods."""

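tree(add_progress=True) builds a one-task Progress per downloadable track, stacks it under the track label with a Table.grid, hangs that on the Tree, and returns partial(progress.update, task_id=task) callables so each download worker can drive its own bar without ever touching the tree. A reduced sketch of that wiring with invented track names:

# Reduced sketch (invented track names) of attaching a per-item Progress to a Tree node
# and exposing it as a pre-bound update callable.
from functools import partial
from rich.console import Console
from rich.progress import Progress, BarColumn, TextColumn
from rich.table import Table
from rich.tree import Tree

console = Console()
tree = Tree("", hide_root=True)
branch = tree.add("2 Videos")

progress_callables = []
for name in ("1080p H.264", "720p H.264"):
    progress = Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("{task.fields[downloaded]}"),
        console=console
    )
    task = progress.add_task("", downloaded="-")
    progress_callables.append(partial(progress.update, task_id=task))

    row = Table.grid()
    row.add_row(name)
    row.add_row(progress)  # Progress renders inline as part of the grid
    branch.add(row)

progress_callables[0](total=100, completed=40, downloaded="2.1 MB/s")
console.print(tree)  # static snapshot; devine keeps this refreshing inside a Live display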