From 0e6aa1d5e8908093ef15abe8c4b5c015a069e469 Mon Sep 17 00:00:00 2001 From: rlaphoenix Date: Wed, 8 Nov 2023 21:25:28 +0000 Subject: [PATCH] Various typing/linting fixes and improvements --- pywidevine/cdm.py | 13 +++++----- pywidevine/device.py | 24 +++++++++--------- pywidevine/main.py | 43 ++++++++++++++------------------ pywidevine/pssh.py | 18 +++++++------- pywidevine/remotecdm.py | 6 ++--- pywidevine/serve.py | 55 +++++++++++++++++++++-------------------- 6 files changed, 77 insertions(+), 82 deletions(-) diff --git a/pywidevine/cdm.py b/pywidevine/cdm.py index 0ffd23f..048c89c 100644 --- a/pywidevine/cdm.py +++ b/pywidevine/cdm.py @@ -510,8 +510,7 @@ class Cdm: input_file = Path(input_file) output_file = Path(output_file) - if temp_dir: - temp_dir = Path(temp_dir) + temp_dir_ = Path(temp_dir) if temp_dir else None if not input_file.is_file(): raise FileNotFoundError(f"Input file does not exist, {input_file}") @@ -545,9 +544,9 @@ class Cdm: ]) ] - if temp_dir: - temp_dir.mkdir(parents=True, exist_ok=True) - args.extend(["--temp_dir", temp_dir]) + if temp_dir_: + temp_dir_.mkdir(parents=True, exist_ok=True) + args.extend(["--temp_dir", str(temp_dir_)]) return subprocess.check_call([executable, *args]) @@ -555,8 +554,8 @@ class Cdm: def encrypt_client_id( client_id: ClientIdentification, service_certificate: Union[SignedDrmCertificate, DrmCertificate], - key: bytes = None, - iv: bytes = None + key: Optional[bytes] = None, + iv: Optional[bytes] = None ) -> EncryptedClientIdentification: """Encrypt the Client ID with the Service's Privacy Certificate.""" privacy_key = key or get_random_bytes(16) diff --git a/pywidevine/device.py b/pywidevine/device.py index d4d2c5d..1deeaf4 100644 --- a/pywidevine/device.py +++ b/pywidevine/device.py @@ -199,36 +199,36 @@ class Device: raise ValueError("Device Data does not seem to be a WVD file (v0).") if header.version == 1: # v1 to v2 - data = _Structures.v1.parse(data) - data.version = 2 # update version to 2 to allow loading - data.flags = Container() # blank flags that may have been used in v1 + v1_struct = _Structures.v1.parse(data) + v1_struct.version = 2 # update version to 2 to allow loading + v1_struct.flags = Container() # blank flags that may have been used in v1 vmp = FileHashes() - if data.vmp: + if v1_struct.vmp: try: - vmp.ParseFromString(data.vmp) - if vmp.SerializeToString() != data.vmp: + vmp.ParseFromString(v1_struct.vmp) + if vmp.SerializeToString() != v1_struct.vmp: raise DecodeError("partial parse") except DecodeError as e: raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}") - data.vmp = vmp + v1_struct.vmp = vmp client_id = ClientIdentification() try: - client_id.ParseFromString(data.client_id) - if client_id.SerializeToString() != data.client_id: + client_id.ParseFromString(v1_struct.client_id) + if client_id.SerializeToString() != v1_struct.client_id: raise DecodeError("partial parse") except DecodeError as e: raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}") - new_vmp_data = data.vmp.SerializeToString() + new_vmp_data = v1_struct.vmp.SerializeToString() if client_id.vmp_data and client_id.vmp_data != new_vmp_data: logging.getLogger("migrate").warning("Client ID already has Verified Media Path data") client_id.vmp_data = new_vmp_data - data.client_id = client_id.SerializeToString() + v1_struct.client_id = client_id.SerializeToString() try: - data = _Structures.v2.build(data) + data = _Structures.v2.build(v1_struct) except ConstructError as e: raise ValueError(f"Migration failed, {e}") diff --git a/pywidevine/main.py b/pywidevine/main.py index f6a52c4..09c29d2 100644 --- a/pywidevine/main.py +++ b/pywidevine/main.py @@ -26,10 +26,8 @@ def main(version: bool, debug: bool) -> None: logging.basicConfig(level=logging.DEBUG if debug else logging.INFO) log = logging.getLogger() - copyright_years = 2022 current_year = datetime.now().year - if copyright_years != current_year: - copyright_years = f"{copyright_years}-{current_year}" + copyright_years = f"2022-{current_year}" log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years) log.info("https://github.com/rlaphoenix/pywidevine") @@ -38,15 +36,15 @@ def main(version: bool, debug: bool) -> None: @main.command(name="license") -@click.argument("device", type=Path) -@click.argument("pssh", type=str) +@click.argument("device_path", type=Path) +@click.argument("pssh", type=PSSH) @click.argument("server", type=str) @click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False), default="STREAMING", help="License Type to Request.") @click.option("-p", "--privacy", is_flag=True, default=False, help="Use Privacy Mode, off by default.") -def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool): +def license_(device_path: Path, pssh: PSSH, server: str, type_: str, privacy: bool) -> None: """ Make a License Request for PSSH to SERVER using DEVICE. It will return a list of all keys within the returned license. @@ -65,11 +63,8 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool): """ log = logging.getLogger("license") - # prepare pssh - pssh = PSSH(pssh) - # load device - device = Device.load(device) + device = Device.load(device_path) log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level) log.debug(device) @@ -84,18 +79,18 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool): if privacy: # get service cert for license server via cert challenge - service_cert = requests.post( + service_cert_res = requests.post( url=server, data=cdm.service_certificate_challenge ) - if service_cert.status_code != 200: + if service_cert_res.status_code != 200: log.error( "[-] Failed to get Service Privacy Certificate: [%s] %s", - service_cert.status_code, - service_cert.text + service_cert_res.status_code, + service_cert_res.text ) return - service_cert = service_cert.content + service_cert = service_cert_res.content provider_id = cdm.set_service_certificate(session_id, service_cert) log.info("[+] Set Service Privacy Certificate: %s", provider_id) log.debug(service_cert) @@ -107,14 +102,14 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool): log.debug(challenge) # send license challenge - licence = requests.post( + license_res = requests.post( url=server, data=challenge ) - if licence.status_code != 200: - log.error("[-] Failed to send challenge: [%s] %s", licence.status_code, licence.text) + if license_res.status_code != 200: + log.error("[-] Failed to send challenge: [%s] %s", license_res.status_code, license_res.text) return - licence = licence.content + licence = license_res.content log.info("[+] Got License Message") log.debug(licence) @@ -135,7 +130,7 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool): @click.option("-p", "--privacy", is_flag=True, default=False, help="Use Privacy Mode, off by default.") @click.pass_context -def test(ctx: click.Context, device: Path, privacy: bool): +def test(ctx: click.Context, device: Path, privacy: bool) -> None: """ Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example. https://bitmovin.com/demos/drm @@ -161,7 +156,7 @@ def test(ctx: click.Context, device: Path, privacy: bool): # it will print information as it goes to the terminal ctx.invoke( license_, - device=device, + device_path=device, pssh=pssh, server=license_server, type_=LicenseType.Name(license_type), @@ -382,10 +377,10 @@ def migrate(ctx: click.Context, path: Path) -> None: @main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.") -@click.argument("config", type=Path) +@click.argument("config_path", type=Path) @click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.") @click.option("-p", "--port", type=int, default=8786, help="Port to serve from.") -def serve_(config: Path, host: str, port: int): +def serve_(config_path: Path, host: str, port: int) -> None: """ Serve your local CDM and Widevine Devices Remotely. @@ -400,5 +395,5 @@ def serve_(config: Path, host: str, port: int): from pywidevine import serve # isort:skip import yaml # isort:skip - config = yaml.safe_load(config.read_text(encoding="utf8")) + config = yaml.safe_load(config_path.read_text(encoding="utf8")) serve.run(config, host, port) diff --git a/pywidevine/pssh.py b/pywidevine/pssh.py index 32c619f..14096e8 100644 --- a/pywidevine/pssh.py +++ b/pywidevine/pssh.py @@ -82,17 +82,17 @@ class PSSH: box = Box.parse(data) except (IOError, construct.ConstructError): # not a box try: - cenc_header = WidevinePsshData() - cenc_header.ParseFromString(data) - cenc_header = cenc_header.SerializeToString() - if cenc_header != data: # not actually a WidevinePsshData + widevine_pssh_data = WidevinePsshData() + widevine_pssh_data.ParseFromString(data) + data_serialized = widevine_pssh_data.SerializeToString() + if data_serialized != data: # not actually a WidevinePsshData raise DecodeError() box = Box.parse(Box.build(dict( type=b"pssh", version=0, flags=0, system_ID=PSSH.SystemId.Widevine, - init_data=cenc_header + init_data=data_serialized ))) except DecodeError: # not a widevine cenc header if "".encode("utf-16-le") in data: @@ -307,16 +307,16 @@ class PSSH: if self.system_id == PSSH.SystemId.Widevine: raise ValueError("This is already a Widevine PSSH") - cenc_header = WidevinePsshData() - cenc_header.algorithm = 1 # 0=Clear, 1=AES-CTR - cenc_header.key_ids[:] = [x.bytes for x in self.key_ids] + widevine_pssh_data = WidevinePsshData() + widevine_pssh_data.algorithm = WidevinePsshData.Algorithm.Value("AESCTR") + widevine_pssh_data.key_ids[:] = [x.bytes for x in self.key_ids] if self.version == 1: # ensure both cenc header and box has same Key IDs # v1 uses both this and within init data for basically no reason self.__key_ids = self.key_ids - self.init_data = cenc_header.SerializeToString() + self.init_data = widevine_pssh_data.SerializeToString() self.system_id = PSSH.SystemId.Widevine def to_playready( diff --git a/pywidevine/remotecdm.py b/pywidevine/remotecdm.py index 6c9254a..35fa15d 100644 --- a/pywidevine/remotecdm.py +++ b/pywidevine/remotecdm.py @@ -86,10 +86,10 @@ class RemoteCdm(Cdm): server = r.headers.get("Server") if not server or "pywidevine serve" not in server.lower(): raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).") - server_version = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE) - if not server_version: + server_version_re = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE) + if not server_version_re: raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.") - server_version = server_version.group(1) + server_version = server_version_re.group(1) if server_version < "1.4.3": raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.") diff --git a/pywidevine/serve.py b/pywidevine/serve.py index b09f9e7..37a0c35 100644 --- a/pywidevine/serve.py +++ b/pywidevine/serve.py @@ -1,8 +1,9 @@ import base64 import sys from pathlib import Path -from typing import Optional, Union +from typing import Any, Optional, Union +from aiohttp.typedefs import Handler from google.protobuf.message import DecodeError from pywidevine.pssh import PSSH @@ -26,8 +27,8 @@ from pywidevine.exceptions import (InvalidContext, InvalidInitData, InvalidLicen routes = web.RouteTableDef() -async def _startup(app: web.Application): - app["cdms"]: dict[tuple[str, str], Cdm] = {} +async def _startup(app: web.Application) -> None: + app["cdms"] = {} app["config"]["devices"] = { path.stem: path for x in app["config"]["devices"] @@ -38,7 +39,7 @@ async def _startup(app: web.Application): raise FileNotFoundError(f"Device file does not exist: {device}") -async def _cleanup(app: web.Application): +async def _cleanup(app: web.Application) -> None: app["cdms"].clear() del app["cdms"] app["config"].clear() @@ -46,7 +47,7 @@ async def _cleanup(app: web.Application): @routes.get("/") -async def ping(_) -> web.Response: +async def ping(_: Any) -> web.Response: return web.json_response({ "status": 200, "message": "Pong!" @@ -211,13 +212,15 @@ async def get_service_certificate(request: web.Request) -> web.Response: }, status=400) if service_certificate: - service_certificate = base64.b64encode(service_certificate.SerializeToString()).decode() + service_certificate_b64 = base64.b64encode(service_certificate.SerializeToString()).decode() + else: + service_certificate_b64 = None return web.json_response({ "status": 200, "message": "Successfully got the Service Certificate.", "data": { - "service_certificate": service_certificate + "service_certificate": service_certificate_b64 } }) @@ -366,7 +369,7 @@ async def get_keys(request: web.Request) -> web.Response: session_id = bytes.fromhex(body["session_id"]) # get key type - key_type = request.match_info["key_type"] + key_type: Optional[str] = request.match_info["key_type"] if key_type == "ALL": key_type = None @@ -414,26 +417,24 @@ async def get_keys(request: web.Request) -> web.Response: @web.middleware -async def authentication(request: web.Request, handler) -> web.Response: - response = None - if request.path != "/": - secret_key = request.headers.get("X-Secret-Key") - if not secret_key: - request.app.logger.debug(f"{request.remote} did not provide authorization.") - response = web.json_response({ - "status": "401", - "message": "Secret Key is Empty." - }, status=401) - elif secret_key not in request.app["config"]["users"]: - request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.") - response = web.json_response({ - "status": "401", - "message": "Secret Key is Invalid, the Key is case-sensitive." - }, status=401) +async def authentication(request: web.Request, handler: Handler) -> web.Response: + secret_key = request.headers.get("X-Secret-Key") - if response is None: + if request.path != "/" and not secret_key: + request.app.logger.debug(f"{request.remote} did not provide authorization.") + response = web.json_response({ + "status": "401", + "message": "Secret Key is Empty." + }, status=401) + elif request.path != "/" and secret_key not in request.app["config"]["users"]: + request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.") + response = web.json_response({ + "status": "401", + "message": "Secret Key is Invalid, the Key is case-sensitive." + }, status=401) + else: try: - response = await handler(request) + response = await handler(request) # type: ignore[assignment] except web.HTTPException as e: request.app.logger.error(f"An unexpected error has occurred, {e}") response = web.json_response({ @@ -448,7 +449,7 @@ async def authentication(request: web.Request, handler) -> web.Response: return response -def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None): +def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None) -> None: app = web.Application(middlewares=[authentication]) app.on_startup.append(_startup) app.on_cleanup.append(_cleanup)