feat(Track): Allow Track to choose downloader to use

The downloader property must be a Callable of the same signature as the aria2c, curl_impersonate, and requests downloader functions. You can pass it these functions by importing, or a custom function of a matching signature.

Note: It will still override the chosen downloader and use a fallback one in the case of using aria2c downloader but the download uses the HTTP Range header.

Closes #70
This commit is contained in:
rlaphoenix 2024-03-08 16:48:44 +00:00
parent ba801739fe
commit 423ff289db
4 changed files with 27 additions and 26 deletions

View File

@ -1,13 +1,5 @@
from ..config import config
from .aria2c import aria2c from .aria2c import aria2c
from .curl_impersonate import curl_impersonate from .curl_impersonate import curl_impersonate
from .requests import requests from .requests import requests
downloader = { __all__ = ("aria2c", "curl_impersonate", "requests")
"aria2c": aria2c,
"curl_impersonate": curl_impersonate,
"requests": requests
}[config.downloader]
__all__ = ("downloader", "aria2c", "curl_impersonate", "requests")

View File

@ -22,7 +22,6 @@ from pywidevine.pssh import PSSH
from requests import Session from requests import Session
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from devine.core.downloaders import downloader
from devine.core.downloaders import requests as requests_downloader from devine.core.downloaders import requests as requests_downloader
from devine.core.drm import Widevine from devine.core.drm import Widevine
from devine.core.tracks import Audio, Subtitle, Tracks, Video from devine.core.tracks import Audio, Subtitle, Tracks, Video
@ -452,12 +451,12 @@ class DASH:
progress(total=len(segments)) progress(total=len(segments))
downloader_ = downloader downloader = track.downloader
if downloader.__name__ == "aria2c" and any(bytes_range is not None for url, bytes_range in segments): if downloader.__name__ == "aria2c" and any(bytes_range is not None for url, bytes_range in segments):
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
downloader_ = requests_downloader downloader = requests_downloader
for status_update in downloader_( for status_update in downloader(
urls=[ urls=[
{ {
"url": url, "url": url,

View File

@ -20,7 +20,6 @@ from pywidevine.pssh import PSSH
from requests import Session from requests import Session
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, AnyTrack
from devine.core.downloaders import downloader
from devine.core.downloaders import requests as requests_downloader from devine.core.downloaders import requests as requests_downloader
from devine.core.drm import DRM_T, ClearKey, Widevine from devine.core.drm import DRM_T, ClearKey, Widevine
from devine.core.tracks import Audio, Subtitle, Tracks, Video from devine.core.tracks import Audio, Subtitle, Tracks, Video
@ -247,7 +246,7 @@ class HLS:
total_segments = len(master.segments) - len(unwanted_segments) total_segments = len(master.segments) - len(unwanted_segments)
progress(total=total_segments) progress(total=total_segments)
downloader_ = downloader downloader = track.downloader
urls: list[dict[str, Any]] = [] urls: list[dict[str, Any]] = []
range_offset = 0 range_offset = 0
@ -256,9 +255,9 @@ class HLS:
continue continue
if segment.byterange: if segment.byterange:
if downloader_.__name__ == "aria2c": if downloader.__name__ == "aria2c":
# aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader # aria2(c) is shit and doesn't support the Range header, fallback to the requests downloader
downloader_ = requests_downloader downloader = requests_downloader
byte_range = HLS.calculate_byte_range(segment.byterange, range_offset) byte_range = HLS.calculate_byte_range(segment.byterange, range_offset)
range_offset = byte_range.split("-")[0] range_offset = byte_range.split("-")[0]
else: else:
@ -273,7 +272,7 @@ class HLS:
segment_save_dir = save_dir / "segments" segment_save_dir = save_dir / "segments"
for status_update in downloader_( for status_update in downloader(
urls=urls, urls=urls,
output_dir=segment_save_dir, output_dir=segment_save_dir,
filename="{i:0%d}{ext}" % len(str(len(urls))), filename="{i:0%d}{ext}" % len(str(len(urls))),

View File

@ -13,12 +13,12 @@ from uuid import UUID
from zlib import crc32 from zlib import crc32
import m3u8 import m3u8
import requests
from langcodes import Language from langcodes import Language
from requests import Session
from devine.core.config import config from devine.core.config import config
from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, TERRITORY_MAP from devine.core.constants import DOWNLOAD_CANCELLED, DOWNLOAD_LICENCE_ONLY, TERRITORY_MAP
from devine.core.downloaders import downloader from devine.core.downloaders import aria2c, curl_impersonate, requests
from devine.core.drm import DRM_T, Widevine from devine.core.drm import DRM_T, Widevine
from devine.core.utilities import get_binary_path, get_boxes, try_ensure_utf8 from devine.core.utilities import get_binary_path, get_boxes, try_ensure_utf8
from devine.core.utils.subprocess import ffprobe from devine.core.utils.subprocess import ffprobe
@ -40,6 +40,7 @@ class Track:
name: Optional[str] = None, name: Optional[str] = None,
drm: Optional[Iterable[DRM_T]] = None, drm: Optional[Iterable[DRM_T]] = None,
edition: Optional[str] = None, edition: Optional[str] = None,
downloader: Optional[Callable] = None,
data: Optional[dict] = None, data: Optional[dict] = None,
id_: Optional[str] = None, id_: Optional[str] = None,
) -> None: ) -> None:
@ -59,6 +60,8 @@ class Track:
raise TypeError(f"Expected id_ to be a {str}, not {type(id_)}") raise TypeError(f"Expected id_ to be a {str}, not {type(id_)}")
if not isinstance(edition, (str, type(None))): if not isinstance(edition, (str, type(None))):
raise TypeError(f"Expected edition to be a {str}, not {type(edition)}") raise TypeError(f"Expected edition to be a {str}, not {type(edition)}")
if not isinstance(downloader, (Callable, type(None))):
raise TypeError(f"Expected downloader to be a {Callable}, not {type(downloader)}")
if not isinstance(data, (dict, type(None))): if not isinstance(data, (dict, type(None))):
raise TypeError(f"Expected data to be a {dict}, not {type(data)}") raise TypeError(f"Expected data to be a {dict}, not {type(data)}")
@ -72,6 +75,13 @@ class Track:
except TypeError: except TypeError:
raise TypeError(f"Expected drm to be an iterable, not {type(drm)}") raise TypeError(f"Expected drm to be an iterable, not {type(drm)}")
if downloader is None:
downloader = {
"aria2c": aria2c,
"curl_impersonate": curl_impersonate,
"requests": requests
}[config.downloader]
self.path: Optional[Path] = None self.path: Optional[Path] = None
self.url = url self.url = url
self.language = Language.get(language) self.language = Language.get(language)
@ -81,6 +91,7 @@ class Track:
self.name = name self.name = name
self.drm = drm self.drm = drm
self.edition: str = edition self.edition: str = edition
self.downloader = downloader
self.data = data or {} self.data = data or {}
if not id_: if not id_:
@ -116,7 +127,7 @@ class Track:
def download( def download(
self, self,
session: requests.Session, session: Session,
prepare_drm: partial, prepare_drm: partial,
progress: Optional[partial] = None progress: Optional[partial] = None
): ):
@ -213,7 +224,7 @@ class Track:
if DOWNLOAD_LICENCE_ONLY.is_set(): if DOWNLOAD_LICENCE_ONLY.is_set():
progress(downloaded="[yellow]SKIPPED") progress(downloaded="[yellow]SKIPPED")
else: else:
for status_update in downloader( for status_update in self.downloader(
urls=self.url, urls=self.url,
output_dir=save_path.parent, output_dir=save_path.parent,
filename=save_path.name, filename=save_path.name,
@ -382,7 +393,7 @@ class Track:
maximum_size: int = 20000, maximum_size: int = 20000,
url: Optional[str] = None, url: Optional[str] = None,
byte_range: Optional[str] = None, byte_range: Optional[str] = None,
session: Optional[requests.Session] = None session: Optional[Session] = None
) -> bytes: ) -> bytes:
""" """
Get the Track's Initial Segment Data Stream. Get the Track's Initial Segment Data Stream.
@ -412,8 +423,8 @@ class Track:
raise TypeError(f"Expected url to be a {str}, not {type(url)}") raise TypeError(f"Expected url to be a {str}, not {type(url)}")
if not isinstance(byte_range, (str, type(None))): if not isinstance(byte_range, (str, type(None))):
raise TypeError(f"Expected byte_range to be a {str}, not {type(byte_range)}") raise TypeError(f"Expected byte_range to be a {str}, not {type(byte_range)}")
if not isinstance(session, (requests.Session, type(None))): if not isinstance(session, (Session, type(None))):
raise TypeError(f"Expected session to be a {requests.Session}, not {type(session)}") raise TypeError(f"Expected session to be a {Session}, not {type(session)}")
if not url: if not url:
if self.descriptor != self.Descriptor.URL: if self.descriptor != self.Descriptor.URL:
@ -423,7 +434,7 @@ class Track:
url = self.url url = self.url
if not session: if not session:
session = requests.Session() session = Session()
content_length = maximum_size content_length = maximum_size