Various typing/linting fixes and improvements
This commit is contained in:
parent
97ec2e1c60
commit
0e6aa1d5e8
|
@ -510,8 +510,7 @@ class Cdm:
|
|||
|
||||
input_file = Path(input_file)
|
||||
output_file = Path(output_file)
|
||||
if temp_dir:
|
||||
temp_dir = Path(temp_dir)
|
||||
temp_dir_ = Path(temp_dir) if temp_dir else None
|
||||
|
||||
if not input_file.is_file():
|
||||
raise FileNotFoundError(f"Input file does not exist, {input_file}")
|
||||
|
@ -545,9 +544,9 @@ class Cdm:
|
|||
])
|
||||
]
|
||||
|
||||
if temp_dir:
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.extend(["--temp_dir", temp_dir])
|
||||
if temp_dir_:
|
||||
temp_dir_.mkdir(parents=True, exist_ok=True)
|
||||
args.extend(["--temp_dir", str(temp_dir_)])
|
||||
|
||||
return subprocess.check_call([executable, *args])
|
||||
|
||||
|
@ -555,8 +554,8 @@ class Cdm:
|
|||
def encrypt_client_id(
|
||||
client_id: ClientIdentification,
|
||||
service_certificate: Union[SignedDrmCertificate, DrmCertificate],
|
||||
key: bytes = None,
|
||||
iv: bytes = None
|
||||
key: Optional[bytes] = None,
|
||||
iv: Optional[bytes] = None
|
||||
) -> EncryptedClientIdentification:
|
||||
"""Encrypt the Client ID with the Service's Privacy Certificate."""
|
||||
privacy_key = key or get_random_bytes(16)
|
||||
|
|
|
@ -199,36 +199,36 @@ class Device:
|
|||
raise ValueError("Device Data does not seem to be a WVD file (v0).")
|
||||
|
||||
if header.version == 1: # v1 to v2
|
||||
data = _Structures.v1.parse(data)
|
||||
data.version = 2 # update version to 2 to allow loading
|
||||
data.flags = Container() # blank flags that may have been used in v1
|
||||
v1_struct = _Structures.v1.parse(data)
|
||||
v1_struct.version = 2 # update version to 2 to allow loading
|
||||
v1_struct.flags = Container() # blank flags that may have been used in v1
|
||||
|
||||
vmp = FileHashes()
|
||||
if data.vmp:
|
||||
if v1_struct.vmp:
|
||||
try:
|
||||
vmp.ParseFromString(data.vmp)
|
||||
if vmp.SerializeToString() != data.vmp:
|
||||
vmp.ParseFromString(v1_struct.vmp)
|
||||
if vmp.SerializeToString() != v1_struct.vmp:
|
||||
raise DecodeError("partial parse")
|
||||
except DecodeError as e:
|
||||
raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
|
||||
data.vmp = vmp
|
||||
v1_struct.vmp = vmp
|
||||
|
||||
client_id = ClientIdentification()
|
||||
try:
|
||||
client_id.ParseFromString(data.client_id)
|
||||
if client_id.SerializeToString() != data.client_id:
|
||||
client_id.ParseFromString(v1_struct.client_id)
|
||||
if client_id.SerializeToString() != v1_struct.client_id:
|
||||
raise DecodeError("partial parse")
|
||||
except DecodeError as e:
|
||||
raise DecodeError(f"Failed to parse VMP data as FileHashes, {e}")
|
||||
|
||||
new_vmp_data = data.vmp.SerializeToString()
|
||||
new_vmp_data = v1_struct.vmp.SerializeToString()
|
||||
if client_id.vmp_data and client_id.vmp_data != new_vmp_data:
|
||||
logging.getLogger("migrate").warning("Client ID already has Verified Media Path data")
|
||||
client_id.vmp_data = new_vmp_data
|
||||
data.client_id = client_id.SerializeToString()
|
||||
v1_struct.client_id = client_id.SerializeToString()
|
||||
|
||||
try:
|
||||
data = _Structures.v2.build(data)
|
||||
data = _Structures.v2.build(v1_struct)
|
||||
except ConstructError as e:
|
||||
raise ValueError(f"Migration failed, {e}")
|
||||
|
||||
|
|
|
@ -26,10 +26,8 @@ def main(version: bool, debug: bool) -> None:
|
|||
logging.basicConfig(level=logging.DEBUG if debug else logging.INFO)
|
||||
log = logging.getLogger()
|
||||
|
||||
copyright_years = 2022
|
||||
current_year = datetime.now().year
|
||||
if copyright_years != current_year:
|
||||
copyright_years = f"{copyright_years}-{current_year}"
|
||||
copyright_years = f"2022-{current_year}"
|
||||
|
||||
log.info("pywidevine version %s Copyright (c) %s rlaphoenix", __version__, copyright_years)
|
||||
log.info("https://github.com/rlaphoenix/pywidevine")
|
||||
|
@ -38,15 +36,15 @@ def main(version: bool, debug: bool) -> None:
|
|||
|
||||
|
||||
@main.command(name="license")
|
||||
@click.argument("device", type=Path)
|
||||
@click.argument("pssh", type=str)
|
||||
@click.argument("device_path", type=Path)
|
||||
@click.argument("pssh", type=PSSH)
|
||||
@click.argument("server", type=str)
|
||||
@click.option("-t", "--type", "type_", type=click.Choice(LicenseType.keys(), case_sensitive=False),
|
||||
default="STREAMING",
|
||||
help="License Type to Request.")
|
||||
@click.option("-p", "--privacy", is_flag=True, default=False,
|
||||
help="Use Privacy Mode, off by default.")
|
||||
def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
|
||||
def license_(device_path: Path, pssh: PSSH, server: str, type_: str, privacy: bool) -> None:
|
||||
"""
|
||||
Make a License Request for PSSH to SERVER using DEVICE.
|
||||
It will return a list of all keys within the returned license.
|
||||
|
@ -65,11 +63,8 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
|
|||
"""
|
||||
log = logging.getLogger("license")
|
||||
|
||||
# prepare pssh
|
||||
pssh = PSSH(pssh)
|
||||
|
||||
# load device
|
||||
device = Device.load(device)
|
||||
device = Device.load(device_path)
|
||||
log.info("[+] Loaded Device (%s L%s)", device.system_id, device.security_level)
|
||||
log.debug(device)
|
||||
|
||||
|
@ -84,18 +79,18 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
|
|||
|
||||
if privacy:
|
||||
# get service cert for license server via cert challenge
|
||||
service_cert = requests.post(
|
||||
service_cert_res = requests.post(
|
||||
url=server,
|
||||
data=cdm.service_certificate_challenge
|
||||
)
|
||||
if service_cert.status_code != 200:
|
||||
if service_cert_res.status_code != 200:
|
||||
log.error(
|
||||
"[-] Failed to get Service Privacy Certificate: [%s] %s",
|
||||
service_cert.status_code,
|
||||
service_cert.text
|
||||
service_cert_res.status_code,
|
||||
service_cert_res.text
|
||||
)
|
||||
return
|
||||
service_cert = service_cert.content
|
||||
service_cert = service_cert_res.content
|
||||
provider_id = cdm.set_service_certificate(session_id, service_cert)
|
||||
log.info("[+] Set Service Privacy Certificate: %s", provider_id)
|
||||
log.debug(service_cert)
|
||||
|
@ -107,14 +102,14 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
|
|||
log.debug(challenge)
|
||||
|
||||
# send license challenge
|
||||
licence = requests.post(
|
||||
license_res = requests.post(
|
||||
url=server,
|
||||
data=challenge
|
||||
)
|
||||
if licence.status_code != 200:
|
||||
log.error("[-] Failed to send challenge: [%s] %s", licence.status_code, licence.text)
|
||||
if license_res.status_code != 200:
|
||||
log.error("[-] Failed to send challenge: [%s] %s", license_res.status_code, license_res.text)
|
||||
return
|
||||
licence = licence.content
|
||||
licence = license_res.content
|
||||
log.info("[+] Got License Message")
|
||||
log.debug(licence)
|
||||
|
||||
|
@ -135,7 +130,7 @@ def license_(device: Path, pssh: str, server: str, type_: str, privacy: bool):
|
|||
@click.option("-p", "--privacy", is_flag=True, default=False,
|
||||
help="Use Privacy Mode, off by default.")
|
||||
@click.pass_context
|
||||
def test(ctx: click.Context, device: Path, privacy: bool):
|
||||
def test(ctx: click.Context, device: Path, privacy: bool) -> None:
|
||||
"""
|
||||
Test the CDM code by getting Content Keys for Bitmovin's Art of Motion example.
|
||||
https://bitmovin.com/demos/drm
|
||||
|
@ -161,7 +156,7 @@ def test(ctx: click.Context, device: Path, privacy: bool):
|
|||
# it will print information as it goes to the terminal
|
||||
ctx.invoke(
|
||||
license_,
|
||||
device=device,
|
||||
device_path=device,
|
||||
pssh=pssh,
|
||||
server=license_server,
|
||||
type_=LicenseType.Name(license_type),
|
||||
|
@ -382,10 +377,10 @@ def migrate(ctx: click.Context, path: Path) -> None:
|
|||
|
||||
|
||||
@main.command("serve", short_help="Serve your local CDM and Widevine Devices Remotely.")
|
||||
@click.argument("config", type=Path)
|
||||
@click.argument("config_path", type=Path)
|
||||
@click.option("-h", "--host", type=str, default="127.0.0.1", help="Host to serve from.")
|
||||
@click.option("-p", "--port", type=int, default=8786, help="Port to serve from.")
|
||||
def serve_(config: Path, host: str, port: int):
|
||||
def serve_(config_path: Path, host: str, port: int) -> None:
|
||||
"""
|
||||
Serve your local CDM and Widevine Devices Remotely.
|
||||
|
||||
|
@ -400,5 +395,5 @@ def serve_(config: Path, host: str, port: int):
|
|||
from pywidevine import serve # isort:skip
|
||||
import yaml # isort:skip
|
||||
|
||||
config = yaml.safe_load(config.read_text(encoding="utf8"))
|
||||
config = yaml.safe_load(config_path.read_text(encoding="utf8"))
|
||||
serve.run(config, host, port)
|
||||
|
|
|
@ -82,17 +82,17 @@ class PSSH:
|
|||
box = Box.parse(data)
|
||||
except (IOError, construct.ConstructError): # not a box
|
||||
try:
|
||||
cenc_header = WidevinePsshData()
|
||||
cenc_header.ParseFromString(data)
|
||||
cenc_header = cenc_header.SerializeToString()
|
||||
if cenc_header != data: # not actually a WidevinePsshData
|
||||
widevine_pssh_data = WidevinePsshData()
|
||||
widevine_pssh_data.ParseFromString(data)
|
||||
data_serialized = widevine_pssh_data.SerializeToString()
|
||||
if data_serialized != data: # not actually a WidevinePsshData
|
||||
raise DecodeError()
|
||||
box = Box.parse(Box.build(dict(
|
||||
type=b"pssh",
|
||||
version=0,
|
||||
flags=0,
|
||||
system_ID=PSSH.SystemId.Widevine,
|
||||
init_data=cenc_header
|
||||
init_data=data_serialized
|
||||
)))
|
||||
except DecodeError: # not a widevine cenc header
|
||||
if "</WRMHEADER>".encode("utf-16-le") in data:
|
||||
|
@ -307,16 +307,16 @@ class PSSH:
|
|||
if self.system_id == PSSH.SystemId.Widevine:
|
||||
raise ValueError("This is already a Widevine PSSH")
|
||||
|
||||
cenc_header = WidevinePsshData()
|
||||
cenc_header.algorithm = 1 # 0=Clear, 1=AES-CTR
|
||||
cenc_header.key_ids[:] = [x.bytes for x in self.key_ids]
|
||||
widevine_pssh_data = WidevinePsshData()
|
||||
widevine_pssh_data.algorithm = WidevinePsshData.Algorithm.Value("AESCTR")
|
||||
widevine_pssh_data.key_ids[:] = [x.bytes for x in self.key_ids]
|
||||
|
||||
if self.version == 1:
|
||||
# ensure both cenc header and box has same Key IDs
|
||||
# v1 uses both this and within init data for basically no reason
|
||||
self.__key_ids = self.key_ids
|
||||
|
||||
self.init_data = cenc_header.SerializeToString()
|
||||
self.init_data = widevine_pssh_data.SerializeToString()
|
||||
self.system_id = PSSH.SystemId.Widevine
|
||||
|
||||
def to_playready(
|
||||
|
|
|
@ -86,10 +86,10 @@ class RemoteCdm(Cdm):
|
|||
server = r.headers.get("Server")
|
||||
if not server or "pywidevine serve" not in server.lower():
|
||||
raise ValueError(f"This Remote CDM API does not seem to be a pywidevine serve API ({server}).")
|
||||
server_version = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
|
||||
if not server_version:
|
||||
server_version_re = re.search(r"pywidevine serve v([\d.]+)", server, re.IGNORECASE)
|
||||
if not server_version_re:
|
||||
raise ValueError("The pywidevine server API is not stating the version correctly, cannot continue.")
|
||||
server_version = server_version.group(1)
|
||||
server_version = server_version_re.group(1)
|
||||
if server_version < "1.4.3":
|
||||
raise ValueError(f"This pywidevine serve API version ({server_version}) is not supported.")
|
||||
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import base64
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from aiohttp.typedefs import Handler
|
||||
from google.protobuf.message import DecodeError
|
||||
|
||||
from pywidevine.pssh import PSSH
|
||||
|
@ -26,8 +27,8 @@ from pywidevine.exceptions import (InvalidContext, InvalidInitData, InvalidLicen
|
|||
routes = web.RouteTableDef()
|
||||
|
||||
|
||||
async def _startup(app: web.Application):
|
||||
app["cdms"]: dict[tuple[str, str], Cdm] = {}
|
||||
async def _startup(app: web.Application) -> None:
|
||||
app["cdms"] = {}
|
||||
app["config"]["devices"] = {
|
||||
path.stem: path
|
||||
for x in app["config"]["devices"]
|
||||
|
@ -38,7 +39,7 @@ async def _startup(app: web.Application):
|
|||
raise FileNotFoundError(f"Device file does not exist: {device}")
|
||||
|
||||
|
||||
async def _cleanup(app: web.Application):
|
||||
async def _cleanup(app: web.Application) -> None:
|
||||
app["cdms"].clear()
|
||||
del app["cdms"]
|
||||
app["config"].clear()
|
||||
|
@ -46,7 +47,7 @@ async def _cleanup(app: web.Application):
|
|||
|
||||
|
||||
@routes.get("/")
|
||||
async def ping(_) -> web.Response:
|
||||
async def ping(_: Any) -> web.Response:
|
||||
return web.json_response({
|
||||
"status": 200,
|
||||
"message": "Pong!"
|
||||
|
@ -211,13 +212,15 @@ async def get_service_certificate(request: web.Request) -> web.Response:
|
|||
}, status=400)
|
||||
|
||||
if service_certificate:
|
||||
service_certificate = base64.b64encode(service_certificate.SerializeToString()).decode()
|
||||
service_certificate_b64 = base64.b64encode(service_certificate.SerializeToString()).decode()
|
||||
else:
|
||||
service_certificate_b64 = None
|
||||
|
||||
return web.json_response({
|
||||
"status": 200,
|
||||
"message": "Successfully got the Service Certificate.",
|
||||
"data": {
|
||||
"service_certificate": service_certificate
|
||||
"service_certificate": service_certificate_b64
|
||||
}
|
||||
})
|
||||
|
||||
|
@ -366,7 +369,7 @@ async def get_keys(request: web.Request) -> web.Response:
|
|||
session_id = bytes.fromhex(body["session_id"])
|
||||
|
||||
# get key type
|
||||
key_type = request.match_info["key_type"]
|
||||
key_type: Optional[str] = request.match_info["key_type"]
|
||||
if key_type == "ALL":
|
||||
key_type = None
|
||||
|
||||
|
@ -414,26 +417,24 @@ async def get_keys(request: web.Request) -> web.Response:
|
|||
|
||||
|
||||
@web.middleware
|
||||
async def authentication(request: web.Request, handler) -> web.Response:
|
||||
response = None
|
||||
if request.path != "/":
|
||||
async def authentication(request: web.Request, handler: Handler) -> web.Response:
|
||||
secret_key = request.headers.get("X-Secret-Key")
|
||||
if not secret_key:
|
||||
|
||||
if request.path != "/" and not secret_key:
|
||||
request.app.logger.debug(f"{request.remote} did not provide authorization.")
|
||||
response = web.json_response({
|
||||
"status": "401",
|
||||
"message": "Secret Key is Empty."
|
||||
}, status=401)
|
||||
elif secret_key not in request.app["config"]["users"]:
|
||||
elif request.path != "/" and secret_key not in request.app["config"]["users"]:
|
||||
request.app.logger.debug(f"{request.remote} failed authentication with '{secret_key}'.")
|
||||
response = web.json_response({
|
||||
"status": "401",
|
||||
"message": "Secret Key is Invalid, the Key is case-sensitive."
|
||||
}, status=401)
|
||||
|
||||
if response is None:
|
||||
else:
|
||||
try:
|
||||
response = await handler(request)
|
||||
response = await handler(request) # type: ignore[assignment]
|
||||
except web.HTTPException as e:
|
||||
request.app.logger.error(f"An unexpected error has occurred, {e}")
|
||||
response = web.json_response({
|
||||
|
@ -448,7 +449,7 @@ async def authentication(request: web.Request, handler) -> web.Response:
|
|||
return response
|
||||
|
||||
|
||||
def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None):
|
||||
def run(config: dict, host: Optional[Union[str, web.HostSequence]] = None, port: Optional[int] = None) -> None:
|
||||
app = web.Application(middlewares=[authentication])
|
||||
app.on_startup.append(_startup)
|
||||
app.on_cleanup.append(_cleanup)
|
||||
|
|
Loading…
Reference in New Issue