diff --git a/pywidevine/cdm.py b/pywidevine/cdm.py index 927136e..177da4f 100644 --- a/pywidevine/cdm.py +++ b/pywidevine/cdm.py @@ -60,7 +60,7 @@ class Cdm: NUM_OF_SESSIONS = 0 MAX_NUM_OF_SESSIONS = 50 # most common limit - def __init__(self, device: Device, pssh: Union[Container, bytes, str], raw: bool = False): + def __init__(self, device: Device, pssh: Union[Container, bytes, str]): """ Open a Widevine Content Decryption Module (CDM) session. @@ -69,9 +69,6 @@ class Cdm: more device-specific information. pssh: Protection System Specific Header Box or Init Data. This should be a compliant mp4 pssh box, or just the init data (Widevine Cenc Header). - raw: This should be set to True if the PSSH data provided is arbitrary data. - E.g., a PSSH Box where the init data is not a Widevine Cenc Header, or - is simply arbitrary data. Devices have a limit on how many sessions can be open and active concurrently. The limit is different for each device and security level, most commonly 50. @@ -92,11 +89,7 @@ class Cdm: self.NUM_OF_SESSIONS += 1 self.device = device - self.init_data = pssh - - if not raw: - # we only want the init_data of the pssh box - self.init_data = PSSH.get_as_box(pssh).init_data + self.init_data = PSSH.get_as_box(pssh).init_data self.session_id = get_random_bytes(16) self.service_certificate: Optional[DrmCertificate] = None diff --git a/pywidevine/pssh.py b/pywidevine/pssh.py index 2ef4db9..fa261ab 100644 --- a/pywidevine/pssh.py +++ b/pywidevine/pssh.py @@ -1,9 +1,11 @@ from __future__ import annotations import base64 +import binascii from typing import Union from uuid import UUID +import construct from construct import Container from google.protobuf.message import DecodeError from lxml import etree @@ -78,36 +80,59 @@ class PSSH: return box @staticmethod - def get_as_box(data: Union[Container, bytes, str]) -> Container: + def get_as_box(data: Union[Container, bytes, str], strict: bool = False) -> Container: """ - Get the possibly arbitrary data as a parsed PSSH mp4 box. - If the data is just Widevine PSSH Data (init data) then it will be crafted - into a new PSSH mp4 box. - If the data could not be recognized as a PSSH box of some form of encoding - it will raise a ValueError. + Get possibly arbitrary data as a parsed PSSH mp4 box. + + Parameters: + data: PSSH mp4 box, Widevine Cenc Header (init data), or arbitrary data to + parse or craft into a PSSH mp4 box. + strict: Do not return a PSSH box for arbitrary data. Require the data to be + at least a PSSH mp4 box, or a Widevine Cenc Header. + + Raises: + ValueError: If the data is empty, or an unexpected type. + binascii.Error: If the data could not be decoded as Base64 if provided + as a string. + construct.ConstructError: If the data could not be parsed as a PSSH mp4 box + nor a Widevine Cenc Header while strict=True. """ - if isinstance(data, str): - data = base64.b64decode(data) - if isinstance(data, bytes): - if base64.b64encode(data).startswith(b"CAES"): # likely widevine pssh data - try: - cenc_header = WidevinePsshData() - cenc_header.ParseFromString(data) - except DecodeError: - # not actually init data after all - pass - else: - data = Box.parse(Box.build(dict( - type=b"pssh", - version=0, - flags=0, - system_ID=PSSH.SystemId.Widevine, - init_data=cenc_header.SerializeToString() - ))) - data = Box.parse(data) + if not data: + raise ValueError("Data must not be empty.") + if isinstance(data, Container): return data - raise ValueError(f"Unrecognized PSSH data: {data!r}") + + if isinstance(data, str): + try: + data = base64.b64decode(data) + except (binascii.Error, binascii.Incomplete) as e: + raise binascii.Error(f"Could not decode data as Base64, {e}") + + if isinstance(data, bytes): + try: + data = Box.parse(data) + except construct.ConstructError: + if strict: + try: + cenc_header = WidevinePsshData() + if cenc_header.MergeFromString(data) < len(data): + raise DecodeError() + except DecodeError: + raise DecodeError(f"Could not parse data as a PSSH mp4 box nor a Widevine Cenc Header.") + else: + data = cenc_header.SerializeToString() + data = Box.parse(Box.build(dict( + type=b"pssh", + version=0, + flags=0, + system_ID=PSSH.SystemId.Widevine, + init_data=data + ))) + else: + raise ValueError(f"Data is an unexpected type, expected bytes got {data!r}.") + + return data @staticmethod def get_key_ids(box: Container) -> list[UUID]: