diff --git a/devine/core/downloaders/__init__.py b/devine/core/downloaders/__init__.py index 7dfb36c..d59e55e 100644 --- a/devine/core/downloaders/__init__.py +++ b/devine/core/downloaders/__init__.py @@ -1,13 +1,5 @@ -from ..config import config from .aria2c import aria2c from .curl_impersonate import curl_impersonate from .requests import requests -downloader = { - "aria2c": aria2c, - "curl_impersonate": curl_impersonate, - "requests": requests -}[config.downloader] - - -__all__ = ("downloader", "aria2c", "curl_impersonate", "requests") +__all__ = ("aria2c", "curl_impersonate", "requests") diff --git a/devine/core/manifests/dash.py b/devine/core/manifests/dash.py index 6580e14..00db1e0 100644 --- a/devine/core/manifests/dash.py +++ b/devine/core/manifests/dash.py @@ -22,7 +22,6 @@ from pywidevine.pssh import PSSH from requests import Session from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack -from devine.core.downloaders import downloader from devine.core.downloaders import requests as requests_downloader from devine.core.drm import Widevine from devine.core.tracks import Audio, Subtitle, Tracks, Video @@ -452,12 +451,12 @@ class DASH: progress(total=len(segments)) - downloader_ = downloader + downloader = track.downloader if downloader.__name__ == "aria2c" and any(bytes_range is not None for url, bytes_range in segments): # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader - downloader_ = requests_downloader + downloader = requests_downloader - for status_update in downloader_( + for status_update in downloader( urls=[ { "url": url, diff --git a/devine/core/manifests/hls.py b/devine/core/manifests/hls.py index 7f47e39..60046e8 100644 --- a/devine/core/manifests/hls.py +++ b/devine/core/manifests/hls.py @@ -20,7 +20,6 @@ from pywidevine.pssh import PSSH from requests import Session from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack -from devine.core.downloaders import downloader from devine.core.downloaders import requests as requests_downloader from devine.core.drm import DRM_T, ClearKey, Widevine from devine.core.tracks import Audio, Subtitle, Tracks, Video @@ -247,7 +246,7 @@ class HLS: total_segments = len(master.segments) - len(unwanted_segments) progress(total=total_segments) - downloader_ = downloader + downloader = track.downloader urls: list[dict[str, Any]] = [] range_offset = 0 @@ -256,9 +255,9 @@ class HLS: continue if segment.byterange: - if downloader_.__name__ == "aria2c": + if downloader.__name__ == "aria2c": # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader - downloader_ = requests_downloader + downloader = requests_downloader byte_range = HLS.calculate_byte_range(segment.byterange, range_offset) range_offset = byte_range.split("-")[0] else: @@ -273,7 +272,7 @@ class HLS: segment_save_dir = save_dir / "segments" - for status_update in downloader_( + for status_update in downloader( urls=urls, output_dir=segment_save_dir, filename="{i:0%d}{ext}" % len(str(len(urls))), diff --git a/devine/core/tracks/track.py b/devine/core/tracks/track.py index 219cce6..db5d445 100644 --- a/devine/core/tracks/track.py +++ b/devine/core/tracks/track.py @@ -13,12 +13,12 @@ from uuid import UUID from zlib import crc32 import m3u8 -import requests from langcodes import Language +from requests import Session from devine.core.config import config from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, TERRITORY_MAP -from devine.core.downloaders import downloader +from devine.core.downloaders import aria2c, curl_impersonate, requests from devine.core.drm import DRM_T, Widevine from devine.core.utilities import get_binary_path, get_boxes, try_ensure_utf8 from devine.core.utils.subprocess import ffprobe @@ -40,6 +40,7 @@ class Track: name: Optional[str] = None, drm: Optional[Iterable[DRM_T]] = None, edition: Optional[str] = None, + downloader: Optional[Callable] = None, data: Optional[dict] = None, id_: Optional[str] = None, ) -> None: @@ -59,6 +60,8 @@ class Track: raise TypeError(f"Expected id_ to be a {str}, not {type(id_)}") if not isinstance(edition, (str, type(None))): raise TypeError(f"Expected edition to be a {str}, not {type(edition)}") + if not isinstance(downloader, (Callable, type(None))): + raise TypeError(f"Expected downloader to be a {Callable}, not {type(downloader)}") if not isinstance(data, (dict, type(None))): raise TypeError(f"Expected data to be a {dict}, not {type(data)}") @@ -72,6 +75,13 @@ class Track: except TypeError: raise TypeError(f"Expected drm to be an iterable, not {type(drm)}") + if downloader is None: + downloader = { + "aria2c": aria2c, + "curl_impersonate": curl_impersonate, + "requests": requests + }[config.downloader] + self.path: Optional[Path] = None self.url = url self.language = Language.get(language) @@ -81,6 +91,7 @@ class Track: self.name = name self.drm = drm self.edition: str = edition + self.downloader = downloader self.data = data or {} if not id_: @@ -116,7 +127,7 @@ class Track: def download( self, - session: requests.Session, + session: Session, prepare_drm: partial, progress: Optional[partial] = None ): @@ -213,7 +224,7 @@ class Track: if DOWNLOAD_LICENCE_ONLY.is_set(): progress(downloaded="[yellow]SKIPPED") else: - for status_update in downloader( + for status_update in self.downloader( urls=self.url, output_dir=save_path.parent, filename=save_path.name, @@ -382,7 +393,7 @@ class Track: maximum_size: int = 20000, url: Optional[str] = None, byte_range: Optional[str] = None, - session: Optional[requests.Session] = None + session: Optional[Session] = None ) -> bytes: """ Get the Track's Initial Segment Data Stream. @@ -412,8 +423,8 @@ class Track: raise TypeError(f"Expected url to be a {str}, not {type(url)}") if not isinstance(byte_range, (str, type(None))): raise TypeError(f"Expected byte_range to be a {str}, not {type(byte_range)}") - if not isinstance(session, (requests.Session, type(None))): - raise TypeError(f"Expected session to be a {requests.Session}, not {type(session)}") + if not isinstance(session, (Session, type(None))): + raise TypeError(f"Expected session to be a {Session}, not {type(session)}") if not url: if self.descriptor != self.Descriptor.URL: @@ -423,7 +434,7 @@ class Track: url = self.url if not session: - session = requests.Session() + session = Session() content_length = maximum_size