Various typing/linting fixes and improvements

This commit is contained in:
rlaphoenix 2023-11-08 21:25:28 +00:00
parent 97ec2e1c60
commit 0e6aa1d5e8
6 changed files with 77 additions and 82 deletions

View File

@ -510,8 +510,7 @@ class Cdm:
input_file = Path(input_file)
output_file = Path(output_file)
if temp_dir:
temp_dir = Path(temp_dir)
temp_dir_ = Path(temp_dir) if temp_dir else None
if not input_file.is_file():
raise FileNotFoundError(f"Input file does not exist, {input_file}")
@ -545,9 +544,9 @@ class Cdm:
])
]
if temp_dir:
temp_dir.mkdir(parents=True, exist_ok=True)
args.extend(["--temp_dir", temp_dir])
if temp_dir_:
temp_dir_.mkdir(parents=True, exist_ok=True)
args.extend(["--temp_dir", str(temp_dir_)])
return subprocess.check_call([executable, *args])
@ -555,8 +554,8 @@ class Cdm:
def encrypt_client_id(
client_id: ClientIdentification,
service_certificate: Union[SignedDrmCertificate, DrmCertificate],
key: bytes = None,
iv: bytes = None
key: Optional[bytes] = None,
iv: Optional[bytes] = None
) -> EncryptedClientIdentification:
"""Encrypt the Client ID with the Service's Privacy Certificate."""
privacy_key = key or get_random_bytes(16)

View File

@ -199,36 +199,36 @@ class Device:
raise ValueError("Device Data does not seem to be a WVD file (v0).")
if header.version == 1: # v1 to v2
data = _Structures.v1.parse(data)
data.version = 2 # update version to 2 to allow loading
data.flags = Container() # blank flags that may have been used in v1
v1_struct = _Structures.v1.parse(data)
v1_struct.version = 2 # update version to 2 to allow loading
v1_struct.flags = Container() # blank flags that may have been used in v1
vmp = FileHashes()
if data.vmp:
if v1_struct.vmp:
try:
vmp.ParseFromString(data.vmp)
if vmp.SerializeToString() != data.vmp:
vmp.ParseFromString(v1_struct.vmp)
if vmp.SerializeToString() != v1_struct.vmp:
raise DecodeError("partial parse")
except DecodeError as e:
raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
data.vmp = vmp
v1_struct.vmp = vmp
client_id = ClientIdentification()
try:
client_id.ParseFromString(data.client_id)
if client_id.SerializeToString() != data.client_id:
client_id.ParseFromString(v1_struct.client_id)
if client_id.SerializeToString() != v1_struct.client_id:
raise DecodeError("partial parse")
except DecodeError as e:
raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
new_vmp_data = data.vmp.SerializeToString()
new_vmp_data = v1_struct.vmp.SerializeToString()
if client_id.vmp_data and client_id.vmp_data != new_vmp_data:
logging.getLogger("migrate").warning("Client ID already has Verified Media Path data")
client_id.vmp_data = new_vmp_data
data.client_id = client_id.SerializeToString()
v1_struct.client_id = client_id.SerializeToString()
try:
data = _Structures.v2.build(data)
data = _Structures.v2.build(v1_struct)
except ConstructError as e:
raise ValueError(f"Migration failed, {e}")

View File

@ -26,10 +26,8 @@ def main(version: bool, debug: bool) -> None:
logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)
log = logging.getLogger()
copyright_years = 2022
current_year = datetime.now().year
if copyright_years != current_year:
copyright_years = f"{copyright_years}-{current_year}"
copyright_years = f"2022-{current_year}"
log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years)
log.info("https://github.com/rlaphoenix/pywidevine")
@ -38,15 +36,15 @@ def main(version: bool, debug: bool) -> None:
@main.command(name="license")
@click.argument("device", type=Path)
@click.argument("pssh", type=str)
@click.argument("device_path", type=Path)
@click.argument("pssh", type=PSSH)
@click.argument("server", type=str)
@click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False),
default="STREAMING",
help="License Type to Request.")
@click.option("-p", "--privacy", is_flag=True, default=False,
help="Use Privacy Mode, off by default.")
def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
def license_(device_path: Path, pssh: PSSH, server: str, type_: str, privacy: bool) -> None:
"""
Make a License Request for PSSH to SERVER using DEVICE.
It will return a list of all keys within the returned license.
@ -65,11 +63,8 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
"""
log = logging.getLogger("license")
# prepare pssh
pssh = PSSH(pssh)
# load device
device = Device.load(device)
device = Device.load(device_path)
log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level)
log.debug(device)
@ -84,18 +79,18 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
if privacy:
# get service cert for license server via cert challenge
service_cert = requests.post(
service_cert_res = requests.post(
url=server,
data=cdm.service_certificate_challenge
)
if service_cert.status_code != 200:
if service_cert_res.status_code != 200:
log.error(
"[-] Failed to get Service Privacy Certificate: [%s] %s",
service_cert.status_code,
service_cert.text
service_cert_res.status_code,
service_cert_res.text
)
return
service_cert = service_cert.content
service_cert = service_cert_res.content
provider_id = cdm.set_service_certificate(session_id, service_cert)
log.info("[+] Set Service Privacy Certificate: %s", provider_id)
log.debug(service_cert)
@ -107,14 +102,14 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
log.debug(challenge)
# send license challenge
licence = requests.post(
license_res = requests.post(
url=server,
data=challenge
)
if licence.status_code != 200:
log.error("[-] Failed to send challenge: [%s] %s", licence.status_code, licence.text)
if license_res.status_code != 200:
log.error("[-] Failed to send challenge: [%s] %s", license_res.status_code, license_res.text)
return
licence = licence.content
licence = license_res.content
log.info("[+] Got License Message")
log.debug(licence)
@ -135,7 +130,7 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
@click.option("-p", "--privacy", is_flag=True, default=False,
help="Use Privacy Mode, off by default.")
@click.pass_context
def test(ctx: click.Context, device: Path, privacy: bool):
def test(ctx: click.Context, device: Path, privacy: bool) -> None:
"""
Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example.
https://bitmovin.com/demos/drm
@ -161,7 +156,7 @@ def test(ctx: click.Context, device: Path, privacy: bool):
# it will print information as it goes to the terminal
ctx.invoke(
license_,
device=device,
device_path=device,
pssh=pssh,
server=license_server,
type_=LicenseType.Name(license_type),
@ -382,10 +377,10 @@ def migrate(ctx: click.Context, path: Path) -> None:
@main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.")
@click.argument("config", type=Path)
@click.argument("config_path", type=Path)
@click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.")
@click.option("-p", "--port", type=int, default=8786, help="Port to serve from.")
def serve_(config: Path, host: str, port: int):
def serve_(config_path: Path, host: str, port: int) -> None:
"""
Serve your local CDM and Widevine Devices Remotely.
@ -400,5 +395,5 @@ def serve_(config: Path, host: str, port: int):
from pywidevine import serve # isort:skip
import yaml # isort:skip
config = yaml.safe_load(config.read_text(encoding="utf8"))
config = yaml.safe_load(config_path.read_text(encoding="utf8"))
serve.run(config, host, port)

View File

@ -82,17 +82,17 @@ class PSSH:
box = Box.parse(data)
except (IOError, construct.ConstructError): # not a box
try:
cenc_header = WidevinePsshData()
cenc_header.ParseFromString(data)
cenc_header = cenc_header.SerializeToString()
if cenc_header != data: # not actually a WidevinePsshData
widevine_pssh_data = WidevinePsshData()
widevine_pssh_data.ParseFromString(data)
data_serialized = widevine_pssh_data.SerializeToString()
if data_serialized != data: # not actually a WidevinePsshData
raise DecodeError()
box = Box.parse(Box.build(dict(
type=b"pssh",
version=0,
flags=0,
system_ID=PSSH.SystemId.Widevine,
init_data=cenc_header
init_data=data_serialized
)))
except DecodeError: # not a widevine cenc header
if "</WRMHEADER>".encode("utf-16-le") in data:
@ -307,16 +307,16 @@ class PSSH:
if self.system_id == PSSH.SystemId.Widevine:
raise ValueError("This is already a Widevine PSSH")
cenc_header = WidevinePsshData()
cenc_header.algorithm = 1 # 0=Clear, 1=AES-CTR
cenc_header.key_ids[:] = [x.bytes for x in self.key_ids]
widevine_pssh_data = WidevinePsshData()
widevine_pssh_data.algorithm = WidevinePsshData.Algorithm.Value("AESCTR")
widevine_pssh_data.key_ids[:] = [x.bytes for x in self.key_ids]
if self.version == 1:
# ensure both cenc header and box has same Key IDs
# v1 uses both this and within init data for basically no reason
self.__key_ids = self.key_ids
self.init_data = cenc_header.SerializeToString()
self.init_data = widevine_pssh_data.SerializeToString()
self.system_id = PSSH.SystemId.Widevine
def to_playready(

View File

@ -86,10 +86,10 @@ class RemoteCdm(Cdm):
server = r.headers.get("Server")
if not server or "pywidevine serve" not in server.lower():
raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).")
server_version = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
if not server_version:
server_version_re = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
if not server_version_re:
raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.")
server_version = server_version.group(1)
server_version = server_version_re.group(1)
if server_version < "1.4.3":
raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.")

View File

@ -1,8 +1,9 @@
import base64
import sys
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
from aiohttp.typedefs import Handler
from google.protobuf.message import DecodeError
from pywidevine.pssh import PSSH
@ -26,8 +27,8 @@ from pywidevine.exceptions import (InvalidContext, InvalidInitData, InvalidLicen
routes = web.RouteTableDef()
async def _startup(app: web.Application):
app["cdms"]: dict[tuple[str, str], Cdm] = {}
async def _startup(app: web.Application) -> None:
app["cdms"] = {}
app["config"]["devices"] = {
path.stem: path
for x in app["config"]["devices"]
@ -38,7 +39,7 @@ async def _startup(app: web.Application):
raise FileNotFoundError(f"Device file does not exist: {device}")
async def _cleanup(app: web.Application):
async def _cleanup(app: web.Application) -> None:
app["cdms"].clear()
del app["cdms"]
app["config"].clear()
@ -46,7 +47,7 @@ async def _cleanup(app: web.Application):
@routes.get("/")
async def ping(_) -> web.Response:
async def ping(_: Any) -> web.Response:
return web.json_response({
"status": 200,
"message": "Pong!"
@ -211,13 +212,15 @@ async def get_service_certificate(request: web.Request) -> web.Response:
}, status=400)
if service_certificate:
service_certificate = base64.b64encode(service_certificate.SerializeToString()).decode()
service_certificate_b64 = base64.b64encode(service_certificate.SerializeToString()).decode()
else:
service_certificate_b64 = None
return web.json_response({
"status": 200,
"message": "Successfully got the Service Certificate.",
"data": {
"service_certificate": service_certificate
"service_certificate": service_certificate_b64
}
})
@ -366,7 +369,7 @@ async def get_keys(request: web.Request) -> web.Response:
session_id = bytes.fromhex(body["session_id"])
# get key type
key_type = request.match_info["key_type"]
key_type: Optional[str] = request.match_info["key_type"]
if key_type == "ALL":
key_type = None
@ -414,26 +417,24 @@ async def get_keys(request: web.Request) -> web.Response:
@web.middleware
async def authentication(request: web.Request, handler) -> web.Response:
response = None
if request.path != "/":
secret_key = request.headers.get("X-Secret-Key")
if not secret_key:
request.app.logger.debug(f"{request.remote} did not provide authorization.")
response = web.json_response({
"status": "401",
"message": "Secret Key is Empty."
}, status=401)
elif secret_key not in request.app["config"]["users"]:
request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.")
response = web.json_response({
"status": "401",
"message": "Secret Key is Invalid, the Key is case-sensitive."
}, status=401)
async def authentication(request: web.Request, handler: Handler) -> web.Response:
secret_key = request.headers.get("X-Secret-Key")
if response is None:
if request.path != "/" and not secret_key:
request.app.logger.debug(f"{request.remote} did not provide authorization.")
response = web.json_response({
"status": "401",
"message": "Secret Key is Empty."
}, status=401)
elif request.path != "/" and secret_key not in request.app["config"]["users"]:
request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.")
response = web.json_response({
"status": "401",
"message": "Secret Key is Invalid, the Key is case-sensitive."
}, status=401)
else:
try:
response = await handler(request)
response = await handler(request) # type: ignore[assignment]
except web.HTTPException as e:
request.app.logger.error(f"An unexpected error has occurred, {e}")
response = web.json_response({
@ -448,7 +449,7 @@ async def authentication(request: web.Request, handler) -> web.Response:
return response
def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None):
def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None) -> None:
app = web.Application(middlewares=[authentication])
app.on_startup.append(_startup)
app.on_cleanup.append(_cleanup)