feat(Track): Allow Track to choose downloader to use

The downloader property must be a Callable with the same signature as the aria2c, curl_impersonate, and requests downloader functions. You can import and pass one of these functions, or pass a custom function with a matching signature.

Note: If the chosen downloader is aria2c but the download uses the HTTP Range header, the chosen downloader will still be overridden and the requests downloader will be used as a fallback.

Closes #70
This commit is contained in:
rlaphoenix 2024-03-08 16:48:44 +00:00
parent ba801739fe
commit 423ff289db
4 changed files with 27 additions and 26 deletions

View File

@ -1,13 +1,5 @@
from ..config import config
from .aria2c import aria2c
from .curl_impersonate import curl_impersonate
from .requests import requests
downloader = {
"aria2c": aria2c,
"curl_impersonate": curl_impersonate,
"requests": requests
}[config.downloader]
__all__ = ("downloader", "aria2c", "curl_impersonate", "requests")
__all__ = ("aria2c", "curl_impersonate", "requests")

View File

@ -22,7 +22,6 @@ from pywidevine.pssh import PSSH
from requests import Session
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from devine.core.downloaders import downloader
from devine.core.downloaders import requests as requests_downloader
from devine.core.drm import Widevine
from devine.core.tracks import Audio, Subtitle, Tracks, Video
@ -452,12 +451,12 @@ class DASH:
progress(total=len(segments))
downloader_ = downloader
downloader = track.downloader
if downloader.__name__ == "aria2c" and any(bytes_range is not None for url, bytes_range in segments):
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
downloader_ = requests_downloader
downloader = requests_downloader
for status_update in downloader_(
for status_update in downloader(
urls=[
{
"url": url,

View File

@ -20,7 +20,6 @@ from pywidevine.pssh import PSSH
from requests import Session
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from devine.core.downloaders import downloader
from devine.core.downloaders import requests as requests_downloader
from devine.core.drm import DRM_T, ClearKey, Widevine
from devine.core.tracks import Audio, Subtitle, Tracks, Video
@ -247,7 +246,7 @@ class HLS:
total_segments = len(master.segments) - len(unwanted_segments)
progress(total=total_segments)
downloader_ = downloader
downloader = track.downloader
urls: list[dict[str, Any]] = []
range_offset = 0
@ -256,9 +255,9 @@ class HLS:
continue
if segment.byterange:
if downloader_.__name__ == "aria2c":
if downloader.__name__ == "aria2c":
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
downloader_ = requests_downloader
downloader = requests_downloader
byte_range = HLS.calculate_byte_range(segment.byterange, range_offset)
range_offset = byte_range.split("-")[0]
else:
@ -273,7 +272,7 @@ class HLS:
segment_save_dir = save_dir / "segments"
for status_update in downloader_(
for status_update in downloader(
urls=urls,
output_dir=segment_save_dir,
filename="{i:0%d}{ext}" % len(str(len(urls))),

View File

@ -13,12 +13,12 @@ from uuid import UUID
from zlib import crc32
import m3u8
import requests
from langcodes import Language
from requests import Session
from devine.core.config import config
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, TERRITORY_MAP
from devine.core.downloaders import downloader
from devine.core.downloaders import aria2c, curl_impersonate, requests
from devine.core.drm import DRM_T, Widevine
from devine.core.utilities import get_binary_path, get_boxes, try_ensure_utf8
from devine.core.utils.subprocess import ffprobe
@ -40,6 +40,7 @@ class Track:
name: Optional[str] = None,
drm: Optional[Iterable[DRM_T]] = None,
edition: Optional[str] = None,
downloader: Optional[Callable] = None,
data: Optional[dict] = None,
id_: Optional[str] = None,
) -> None:
@ -59,6 +60,8 @@ class Track:
raise TypeError(f"Expected id_ to be a {str}, not {type(id_)}")
if not isinstance(edition, (str, type(None))):
raise TypeError(f"Expected edition to be a {str}, not {type(edition)}")
if not isinstance(downloader, (Callable, type(None))):
raise TypeError(f"Expected downloader to be a {Callable}, not {type(downloader)}")
if not isinstance(data, (dict, type(None))):
raise TypeError(f"Expected data to be a {dict}, not {type(data)}")
@ -72,6 +75,13 @@ class Track:
except TypeError:
raise TypeError(f"Expected drm to be an iterable, not {type(drm)}")
if downloader is None:
downloader = {
"aria2c": aria2c,
"curl_impersonate": curl_impersonate,
"requests": requests
}[config.downloader]
self.path: Optional[Path] = None
self.url = url
self.language = Language.get(language)
@ -81,6 +91,7 @@ class Track:
self.name = name
self.drm = drm
self.edition: str = edition
self.downloader = downloader
self.data = data or {}
if not id_:
@ -116,7 +127,7 @@ class Track:
def download(
self,
session: requests.Session,
session: Session,
prepare_drm: partial,
progress: Optional[partial] = None
):
@ -213,7 +224,7 @@ class Track:
if DOWNLOAD_LICENCE_ONLY.is_set():
progress(downloaded="[yellow]SKIPPED")
else:
for status_update in downloader(
for status_update in self.downloader(
urls=self.url,
output_dir=save_path.parent,
filename=save_path.name,
@ -382,7 +393,7 @@ class Track:
maximum_size: int = 20000,
url: Optional[str] = None,
byte_range: Optional[str] = None,
session: Optional[requests.Session] = None
session: Optional[Session] = None
) -> bytes:
"""
Get the Track's Initial Segment Data Stream.
@ -412,8 +423,8 @@ class Track:
raise TypeError(f"Expected url to be a {str}, not {type(url)}")
if not isinstance(byte_range, (str, type(None))):
raise TypeError(f"Expected byte_range to be a {str}, not {type(byte_range)}")
if not isinstance(session, (requests.Session, type(None))):
raise TypeError(f"Expected session to be a {requests.Session}, not {type(session)}")
if not isinstance(session, (Session, type(None))):
raise TypeError(f"Expected session to be a {Session}, not {type(session)}")
if not url:
if self.descriptor != self.Descriptor.URL:
@ -423,7 +434,7 @@ class Track:
url = self.url
if not session:
session = requests.Session()
session = Session()
content_length = maximum_size