Various typing/linting fixes and improvements

This commit is contained in:
rlaphoenix 2023-11-08 21:25:28 +00:00
parent 97ec2e1c60
commit 0e6aa1d5e8
6 changed files with 77 additions and 82 deletions

View File

@ -510,8 +510,7 @@ class Cdm:
input_file = Path(input_file) input_file = Path(input_file)
output_file = Path(output_file) output_file = Path(output_file)
if temp_dir: temp_dir_ = Path(temp_dir) if temp_dir else None
temp_dir = Path(temp_dir)
if not input_file.is_file(): if not input_file.is_file():
raise FileNotFoundError(f"Input file does not exist, {input_file}") raise FileNotFoundError(f"Input file does not exist, {input_file}")
@ -545,9 +544,9 @@ class Cdm:
]) ])
] ]
if temp_dir: if temp_dir_:
temp_dir.mkdir(parents=True, exist_ok=True) temp_dir_.mkdir(parents=True, exist_ok=True)
args.extend(["--temp_dir", temp_dir]) args.extend(["--temp_dir", str(temp_dir_)])
return subprocess.check_call([executable, *args]) return subprocess.check_call([executable, *args])
@ -555,8 +554,8 @@ class Cdm:
def encrypt_client_id( def encrypt_client_id(
client_id: ClientIdentification, client_id: ClientIdentification,
service_certificate: Union[SignedDrmCertificate, DrmCertificate], service_certificate: Union[SignedDrmCertificate, DrmCertificate],
key: bytes = None, key: Optional[bytes] = None,
iv: bytes = None iv: Optional[bytes] = None
) -> EncryptedClientIdentification: ) -> EncryptedClientIdentification:
"""Encrypt the Client ID with the Service's Privacy Certificate.""" """Encrypt the Client ID with the Service's Privacy Certificate."""
privacy_key = key or get_random_bytes(16) privacy_key = key or get_random_bytes(16)

View File

@ -199,36 +199,36 @@ class Device:
raise ValueError("Device Data does not seem to be a WVD file (v0).") raise ValueError("Device Data does not seem to be a WVD file (v0).")
if header.version == 1: # v1 to v2 if header.version == 1: # v1 to v2
data = _Structures.v1.parse(data) v1_struct = _Structures.v1.parse(data)
data.version = 2 # update version to 2 to allow loading v1_struct.version = 2 # update version to 2 to allow loading
data.flags = Container() # blank flags that may have been used in v1 v1_struct.flags = Container() # blank flags that may have been used in v1
vmp = FileHashes() vmp = FileHashes()
if data.vmp: if v1_struct.vmp:
try: try:
vmp.ParseFromString(data.vmp) vmp.ParseFromString(v1_struct.vmp)
if vmp.SerializeToString() != data.vmp: if vmp.SerializeToString() != v1_struct.vmp:
raise DecodeError("partial parse") raise DecodeError("partial parse")
except DecodeError as e: except DecodeError as e:
raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}") raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
data.vmp = vmp v1_struct.vmp = vmp
client_id = ClientIdentification() client_id = ClientIdentification()
try: try:
client_id.ParseFromString(data.client_id) client_id.ParseFromString(v1_struct.client_id)
if client_id.SerializeToString() != data.client_id: if client_id.SerializeToString() != v1_struct.client_id:
raise DecodeError("partial parse") raise DecodeError("partial parse")
except DecodeError as e: except DecodeError as e:
raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}") raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
new_vmp_data = data.vmp.SerializeToString() new_vmp_data = v1_struct.vmp.SerializeToString()
if client_id.vmp_data and client_id.vmp_data != new_vmp_data: if client_id.vmp_data and client_id.vmp_data != new_vmp_data:
logging.getLogger("migrate").warning("Client ID already has Verified Media Path data") logging.getLogger("migrate").warning("Client ID already has Verified Media Path data")
client_id.vmp_data = new_vmp_data client_id.vmp_data = new_vmp_data
data.client_id = client_id.SerializeToString() v1_struct.client_id = client_id.SerializeToString()
try: try:
data = _Structures.v2.build(data) data = _Structures.v2.build(v1_struct)
except ConstructError as e: except ConstructError as e:
raise ValueError(f"Migration failed, {e}") raise ValueError(f"Migration failed, {e}")

View File

@ -26,10 +26,8 @@ def main(version: bool, debug: bool) -> None:
logging.basicConfig(level=logging.DEBUG if debug else logging.INFO) logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)
log = logging.getLogger() log = logging.getLogger()
copyright_years = 2022
current_year = datetime.now().year current_year = datetime.now().year
if copyright_years != current_year: copyright_years = f"2022-{current_year}"
copyright_years = f"{copyright_years}-{current_year}"
log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years) log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years)
log.info("https://github.com/rlaphoenix/pywidevine") log.info("https://github.com/rlaphoenix/pywidevine")
@ -38,15 +36,15 @@ def main(version: bool, debug: bool) -> None:
@main.command(name="license") @main.command(name="license")
@click.argument("device", type=Path) @click.argument("device_path", type=Path)
@click.argument("pssh", type=str) @click.argument("pssh", type=PSSH)
@click.argument("server", type=str) @click.argument("server", type=str)
@click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False), @click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False),
default="STREAMING", default="STREAMING",
help="License Type to Request.") help="License Type to Request.")
@click.option("-p", "--privacy", is_flag=True, default=False, @click.option("-p", "--privacy", is_flag=True, default=False,
help="Use Privacy Mode, off by default.") help="Use Privacy Mode, off by default.")
def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool): def license_(device_path: Path, pssh: PSSH, server: str, type_: str, privacy: bool) -> None:
""" """
Make a License Request for PSSH to SERVER using DEVICE. Make a License Request for PSSH to SERVER using DEVICE.
It will return a list of all keys within the returned license. It will return a list of all keys within the returned license.
@ -65,11 +63,8 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
""" """
log = logging.getLogger("license") log = logging.getLogger("license")
# prepare pssh
pssh = PSSH(pssh)
# load device # load device
device = Device.load(device) device = Device.load(device_path)
log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level) log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level)
log.debug(device) log.debug(device)
@ -84,18 +79,18 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
if privacy: if privacy:
# get service cert for license server via cert challenge # get service cert for license server via cert challenge
service_cert = requests.post( service_cert_res = requests.post(
url=server, url=server,
data=cdm.service_certificate_challenge data=cdm.service_certificate_challenge
) )
if service_cert.status_code != 200: if service_cert_res.status_code != 200:
log.error( log.error(
"[-] Failed to get Service Privacy Certificate: [%s] %s", "[-] Failed to get Service Privacy Certificate: [%s] %s",
service_cert.status_code, service_cert_res.status_code,
service_cert.text service_cert_res.text
) )
return return
service_cert = service_cert.content service_cert = service_cert_res.content
provider_id = cdm.set_service_certificate(session_id, service_cert) provider_id = cdm.set_service_certificate(session_id, service_cert)
log.info("[+] Set Service Privacy Certificate: %s", provider_id) log.info("[+] Set Service Privacy Certificate: %s", provider_id)
log.debug(service_cert) log.debug(service_cert)
@ -107,14 +102,14 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
log.debug(challenge) log.debug(challenge)
# send license challenge # send license challenge
licence = requests.post( license_res = requests.post(
url=server, url=server,
data=challenge data=challenge
) )
if licence.status_code != 200: if license_res.status_code != 200:
log.error("[-] Failed to send challenge: [%s] %s", licence.status_code, licence.text) log.error("[-] Failed to send challenge: [%s] %s", license_res.status_code, license_res.text)
return return
licence = licence.content licence = license_res.content
log.info("[+] Got License Message") log.info("[+] Got License Message")
log.debug(licence) log.debug(licence)
@ -135,7 +130,7 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
@click.option("-p", "--privacy", is_flag=True, default=False, @click.option("-p", "--privacy", is_flag=True, default=False,
help="Use Privacy Mode, off by default.") help="Use Privacy Mode, off by default.")
@click.pass_context @click.pass_context
def test(ctx: click.Context, device: Path, privacy: bool): def test(ctx: click.Context, device: Path, privacy: bool) -> None:
""" """
Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example. Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example.
https://bitmovin.com/demos/drm https://bitmovin.com/demos/drm
@ -161,7 +156,7 @@ def test(ctx: click.Context, device: Path, privacy: bool):
# it will print information as it goes to the terminal # it will print information as it goes to the terminal
ctx.invoke( ctx.invoke(
license_, license_,
device=device, device_path=device,
pssh=pssh, pssh=pssh,
server=license_server, server=license_server,
type_=LicenseType.Name(license_type), type_=LicenseType.Name(license_type),
@ -382,10 +377,10 @@ def migrate(ctx: click.Context, path: Path) -> None:
@main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.") @main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.")
@click.argument("config", type=Path) @click.argument("config_path", type=Path)
@click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.") @click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.")
@click.option("-p", "--port", type=int, default=8786, help="Port to serve from.") @click.option("-p", "--port", type=int, default=8786, help="Port to serve from.")
def serve_(config: Path, host: str, port: int): def serve_(config_path: Path, host: str, port: int) -> None:
""" """
Serve your local CDM and Widevine Devices Remotely. Serve your local CDM and Widevine Devices Remotely.
@ -400,5 +395,5 @@ def serve_(config: Path, host: str, port: int):
from pywidevine import serve # isort:skip from pywidevine import serve # isort:skip
import yaml # isort:skip import yaml # isort:skip
config = yaml.safe_load(config.read_text(encoding="utf8")) config = yaml.safe_load(config_path.read_text(encoding="utf8"))
serve.run(config, host, port) serve.run(config, host, port)

View File

@ -82,17 +82,17 @@ class PSSH:
box = Box.parse(data) box = Box.parse(data)
except (IOError, construct.ConstructError): # not a box except (IOError, construct.ConstructError): # not a box
try: try:
cenc_header = WidevinePsshData() widevine_pssh_data = WidevinePsshData()
cenc_header.ParseFromString(data) widevine_pssh_data.ParseFromString(data)
cenc_header = cenc_header.SerializeToString() data_serialized = widevine_pssh_data.SerializeToString()
if cenc_header != data: # not actually a WidevinePsshData if data_serialized != data: # not actually a WidevinePsshData
raise DecodeError() raise DecodeError()
box = Box.parse(Box.build(dict( box = Box.parse(Box.build(dict(
type=b"pssh", type=b"pssh",
version=0, version=0,
flags=0, flags=0,
system_ID=PSSH.SystemId.Widevine, system_ID=PSSH.SystemId.Widevine,
init_data=cenc_header init_data=data_serialized
))) )))
except DecodeError: # not a widevine cenc header except DecodeError: # not a widevine cenc header
if "</WRMHEADER>".encode("utf-16-le") in data: if "</WRMHEADER>".encode("utf-16-le") in data:
@ -307,16 +307,16 @@ class PSSH:
if self.system_id == PSSH.SystemId.Widevine: if self.system_id == PSSH.SystemId.Widevine:
raise ValueError("This is already a Widevine PSSH") raise ValueError("This is already a Widevine PSSH")
cenc_header = WidevinePsshData() widevine_pssh_data = WidevinePsshData()
cenc_header.algorithm = 1 # 0=Clear, 1=AES-CTR widevine_pssh_data.algorithm = WidevinePsshData.Algorithm.Value("AESCTR")
cenc_header.key_ids[:] = [x.bytes for x in self.key_ids] widevine_pssh_data.key_ids[:] = [x.bytes for x in self.key_ids]
if self.version == 1: if self.version == 1:
# ensure both cenc header and box has same Key IDs # ensure both cenc header and box has same Key IDs
# v1 uses both this and within init data for basically no reason # v1 uses both this and within init data for basically no reason
self.__key_ids = self.key_ids self.__key_ids = self.key_ids
self.init_data = cenc_header.SerializeToString() self.init_data = widevine_pssh_data.SerializeToString()
self.system_id = PSSH.SystemId.Widevine self.system_id = PSSH.SystemId.Widevine
def to_playready( def to_playready(

View File

@ -86,10 +86,10 @@ class RemoteCdm(Cdm):
server = r.headers.get("Server") server = r.headers.get("Server")
if not server or "pywidevine serve" not in server.lower(): if not server or "pywidevine serve" not in server.lower():
raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).") raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).")
server_version = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE) server_version_re = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
if not server_version: if not server_version_re:
raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.") raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.")
server_version = server_version.group(1) server_version = server_version_re.group(1)
if server_version < "1.4.3": if server_version < "1.4.3":
raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.") raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.")

View File

@ -1,8 +1,9 @@
import base64 import base64
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Any, Optional, Union
from aiohttp.typedefs import Handler
from google.protobuf.message import DecodeError from google.protobuf.message import DecodeError
from pywidevine.pssh import PSSH from pywidevine.pssh import PSSH
@ -26,8 +27,8 @@ from pywidevine.exceptions import (InvalidContext, InvalidInitData, InvalidLicen
routes = web.RouteTableDef() routes = web.RouteTableDef()
async def _startup(app: web.Application): async def _startup(app: web.Application) -> None:
app["cdms"]: dict[tuple[str, str], Cdm] = {} app["cdms"] = {}
app["config"]["devices"] = { app["config"]["devices"] = {
path.stem: path path.stem: path
for x in app["config"]["devices"] for x in app["config"]["devices"]
@ -38,7 +39,7 @@ async def _startup(app: web.Application):
raise FileNotFoundError(f"Device file does not exist: {device}") raise FileNotFoundError(f"Device file does not exist: {device}")
async def _cleanup(app: web.Application): async def _cleanup(app: web.Application) -> None:
app["cdms"].clear() app["cdms"].clear()
del app["cdms"] del app["cdms"]
app["config"].clear() app["config"].clear()
@ -46,7 +47,7 @@ async def _cleanup(app: web.Application):
@routes.get("/") @routes.get("/")
async def ping(_) -> web.Response: async def ping(_: Any) -> web.Response:
return web.json_response({ return web.json_response({
"status": 200, "status": 200,
"message": "Pong!" "message": "Pong!"
@ -211,13 +212,15 @@ async def get_service_certificate(request: web.Request) -> web.Response:
}, status=400) }, status=400)
if service_certificate: if service_certificate:
service_certificate = base64.b64encode(service_certificate.SerializeToString()).decode() service_certificate_b64 = base64.b64encode(service_certificate.SerializeToString()).decode()
else:
service_certificate_b64 = None
return web.json_response({ return web.json_response({
"status": 200, "status": 200,
"message": "Successfully got the Service Certificate.", "message": "Successfully got the Service Certificate.",
"data": { "data": {
"service_certificate": service_certificate "service_certificate": service_certificate_b64
} }
}) })
@ -366,7 +369,7 @@ async def get_keys(request: web.Request) -> web.Response:
session_id = bytes.fromhex(body["session_id"]) session_id = bytes.fromhex(body["session_id"])
# get key type # get key type
key_type = request.match_info["key_type"] key_type: Optional[str] = request.match_info["key_type"]
if key_type == "ALL": if key_type == "ALL":
key_type = None key_type = None
@ -414,26 +417,24 @@ async def get_keys(request: web.Request) -> web.Response:
@web.middleware @web.middleware
async def authentication(request: web.Request, handler) -> web.Response: async def authentication(request: web.Request, handler: Handler) -> web.Response:
response = None
if request.path != "/":
secret_key = request.headers.get("X-Secret-Key") secret_key = request.headers.get("X-Secret-Key")
if not secret_key:
if request.path != "/" and not secret_key:
request.app.logger.debug(f"{request.remote} did not provide authorization.") request.app.logger.debug(f"{request.remote} did not provide authorization.")
response = web.json_response({ response = web.json_response({
"status": "401", "status": "401",
"message": "Secret Key is Empty." "message": "Secret Key is Empty."
}, status=401) }, status=401)
elif secret_key not in request.app["config"]["users"]: elif request.path != "/" and secret_key not in request.app["config"]["users"]:
request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.") request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.")
response = web.json_response({ response = web.json_response({
"status": "401", "status": "401",
"message": "Secret Key is Invalid, the Key is case-sensitive." "message": "Secret Key is Invalid, the Key is case-sensitive."
}, status=401) }, status=401)
else:
if response is None:
try: try:
response = await handler(request) response = await handler(request) # type: ignore[assignment]
except web.HTTPException as e: except web.HTTPException as e:
request.app.logger.error(f"An unexpected error has occurred, {e}") request.app.logger.error(f"An unexpected error has occurred, {e}")
response = web.json_response({ response = web.json_response({
@ -448,7 +449,7 @@ async def authentication(request: web.Request, handler) -> web.Response:
return response return response
def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None): def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None) -> None:
app = web.Application(middlewares=[authentication]) app = web.Application(middlewares=[authentication])
app.on_startup.append(_startup) app.on_startup.append(_startup)
app.on_cleanup.append(_cleanup) app.on_cleanup.append(_cleanup)