diff --git a/pywidevine/cdm.py b/pywidevine/cdm.py index 839903a..c962b37 100644 --- a/pywidevine/cdm.py +++ b/pywidevine/cdm.py @@ -1,10 +1,11 @@ import base64 +import binascii import random import subprocess import sys import time from pathlib import Path -from typing import Union, Optional +from typing import Union, Container, Optional from uuid import UUID from Crypto.Cipher import AES, PKCS1_OAEP @@ -13,15 +14,17 @@ from Crypto.PublicKey import RSA from Crypto.Random import get_random_bytes from Crypto.Signature import pss from Crypto.Util import Padding -from construct import Container from google.protobuf.message import DecodeError -from pywidevine.utils import get_binary_path -from pywidevine.license_protocol_pb2 import LicenseType, SignedMessage, LicenseRequest, ProtocolVersion, \ - SignedDrmCertificate, DrmCertificate, EncryptedClientIdentification, ClientIdentification, License from pywidevine.device import Device +from pywidevine.exceptions import TooManySessions, InvalidSession, InvalidLicenseType, SignatureMismatch, \ + InvalidInitData, InvalidLicenseMessage, NoKeysLoaded, InvalidContext from pywidevine.key import Key +from pywidevine.license_protocol_pb2 import DrmCertificate, SignedMessage, SignedDrmCertificate, LicenseType, \ + LicenseRequest, ProtocolVersion, ClientIdentification, EncryptedClientIdentification, License from pywidevine.pssh import PSSH +from pywidevine.session import Session +from pywidevine.utils import get_binary_path class Cdm: @@ -57,45 +60,47 @@ class Cdm: root_cert = DrmCertificate() root_cert.ParseFromString(root_signed_cert.drm_certificate) - NUM_OF_SESSIONS = 0 MAX_NUM_OF_SESSIONS = 50 # most common limit - def __init__(self, device: Device, init_data: Union[Container, bytes, str]): + def __init__(self, device: Device): + """Initialize a Widevine Content Decryption Module (CDM).""" + if not device: + raise ValueError("A Widevine Device must be provided.") + self.device = device + + self._sessions: dict[bytes, Session] = {} + + def open(self) -> bytes: """ Open a Widevine Content Decryption Module (CDM) session. - Parameters: - device: Widevine Device containing the Client ID, Device Private Key, and - more device-specific information. - init_data: Widevine Cenc Header (Init Data) or a Protection System Specific - Header Box to take the init data from. - - Devices have a limit on how many sessions can be open and active concurrently. - The limit is different for each device and security level, most commonly 50. - This limit is handled by the OEM Crypto API. Multiple sessions can be open at - a time and sessions should be closed when no longer needed. + Raises: + TooManySessions: If the session cannot be opened as limit has been reached. """ - if not device: - raise ValueError("A Widevine Device must be provided.") - if not init_data: - raise ValueError("Init Data (or a PSSH) must be provided.") + if len(self._sessions) > self.MAX_NUM_OF_SESSIONS: + raise TooManySessions(f"Too many Sessions open ({self.MAX_NUM_OF_SESSIONS}).") - if self.NUM_OF_SESSIONS >= self.MAX_NUM_OF_SESSIONS: - raise ValueError( - f"Too many Sessions open {self.NUM_OF_SESSIONS}/{self.MAX_NUM_OF_SESSIONS}. " - f"Close some Sessions to be able to open more." - ) + session = Session() + self._sessions[session.id] = session - self.NUM_OF_SESSIONS += 1 + return session.id - self.device = device - self.init_data = PSSH.get_as_box(init_data).init_data + def close(self, session_id: bytes) -> None: + """ + Close a Widevine Content Decryption Module (CDM) session. - self.session_id = get_random_bytes(16) - self.service_certificate: Optional[DrmCertificate] = None - self.context: dict[bytes, tuple[bytes, bytes]] = {} + Parameters: + session_id: Session identifier. - def set_service_certificate(self, certificate: Union[bytes, str]) -> str: + Raises: + InvalidSession: If the Session identifier is invalid. + """ + session = self._sessions.get(session_id) + if not session: + raise InvalidSession(f"Session identifier {session_id!r} is invalid.") + del self._sessions[session_id] + + def set_service_certificate(self, session_id: bytes, certificate: Union[bytes, str]) -> str: """ Set a Service Privacy Certificate for Privacy Mode. (optional but recommended) @@ -108,19 +113,31 @@ class Cdm: containing a SignedDrmCertificate. Parameters: + session_id: Session identifier. certificate: SignedDrmCertificate (or SignedMessage containing one) in Base64 or Bytes form obtained from the Service. Some services have their own, but most use the common privacy cert, (common_privacy_cert). Raises: + InvalidSession: If the Session identifier is invalid. DecodeError: If the certificate could not be parsed as a SignedDrmCertificate nor a SignedMessage containing a SignedDrmCertificate. - ValueError: If the SignedDrmCertificate signature is invalid. + SignatureMismatch: If the Signature of the SignedDrmCertificate does not + match the underlying DrmCertificate. Returns the Service Provider ID of the verified DrmCertificate if successful. """ + session = self._sessions.get(session_id) + if not session: + raise InvalidSession(f"Session identifier {session_id!r} is invalid.") + if isinstance(certificate, str): - certificate = base64.b64decode(certificate) # assuming base64 + try: + certificate = base64.b64decode(certificate) # assuming base64 + except binascii.Error: + raise DecodeError("Could not decode certificate string as Base64, expected bytes.") + elif not isinstance(certificate, bytes): + raise DecodeError(f"Expecting Certificate to be bytes, not {certificate!r}") signed_message = SignedMessage() signed_drm_certificate = SignedDrmCertificate() @@ -145,30 +162,64 @@ class Cdm: signature=signed_drm_certificate.signature ) except (ValueError, TypeError): - raise ValueError("Signature Mismatch on SignedDrmCertificate, rejecting certificate") + raise SignatureMismatch("Signature Mismatch on SignedDrmCertificate, rejecting certificate") else: drm_certificate = DrmCertificate() drm_certificate.ParseFromString(signed_drm_certificate.drm_certificate) - self.service_certificate = drm_certificate - return self.service_certificate.provider_id + session.service_certificate = drm_certificate + return drm_certificate.provider_id - def get_license_challenge(self, type_: Union[int, str] = LicenseType.STREAMING, privacy_mode: bool = True) -> bytes: + def get_license_challenge( + self, + session_id: bytes, + init_data: Union[Container, bytes, str], + type_: Union[int, str] = LicenseType.STREAMING, + privacy_mode: bool = True + ) -> bytes: """ - Get a License Challenge to send to a License Server. + Get a License Request (Challenge) to send to a License Server. Parameters: - type_: Type of License you wish to exchange, often `STREAMING`. - The `OFFLINE` Licenses are for Offline licensing of Downloaded content. + session_id: Session identifier. + init_data: Widevine Cenc Header (Init Data) or a Protection System Specific + Header Box to take the init data from. + type_: Type of License you wish to exchange, often `STREAMING`. The `OFFLINE` + Licenses are for Offline licensing of Downloaded content. privacy_mode: Encrypt the Client ID using the Privacy Certificate. If the privacy certificate is not set yet, this does nothing. + Raises: + InvalidSession: If the Session identifier is invalid. + InvalidInitData: If the Init Data (or PSSH box) provided is invalid. + InvalidLicenseType: If the type_ parameter value is not a License Type. It + must be a LicenseType enum, or a string/int representing the enum's keys + or values. + Returns a SignedMessage containing a LicenseRequest message. It's signed with the Private Key of the device provision. """ - request_id = get_random_bytes(16) + session = self._sessions.get(session_id) + if not session: + raise InvalidSession(f"Session identifier {session_id!r} is invalid.") - if isinstance(type_, str): - type_ = LicenseType.Value(type_) + if not init_data: + raise InvalidInitData("The init_data must not be empty.") + try: + init_data = PSSH.get_as_box(init_data).init_data + except (ValueError, binascii.Error, DecodeError) as e: + raise InvalidInitData(str(e)) + + try: + if isinstance(type_, int): + LicenseType.Name(int(type_)) + elif isinstance(type_, str): + type_ = LicenseType.Value(type_) + elif not isinstance(type_, LicenseType): + raise InvalidLicenseType() + except ValueError: + raise InvalidLicenseType(f"License Type {type_!r} is invalid") + + request_id = get_random_bytes(16) license_request = LicenseRequest() license_request.type = LicenseRequest.RequestType.Value("NEW") @@ -176,49 +227,76 @@ class Cdm: license_request.protocol_version = ProtocolVersion.Value("VERSION_2_1") license_request.key_control_nonce = random.randrange(1, 2 ** 31) - license_request.content_id.widevine_pssh_data.pssh_data.append(self.init_data) + license_request.content_id.widevine_pssh_data.pssh_data.append(init_data) license_request.content_id.widevine_pssh_data.license_type = type_ license_request.content_id.widevine_pssh_data.request_id = request_id - if self.service_certificate and privacy_mode: + if session.service_certificate and privacy_mode: # encrypt the client id for privacy mode license_request.encrypted_client_id.CopyFrom(self.encrypt_client_id( client_id=self.device.client_id, - service_certificate=self.service_certificate + service_certificate=session.service_certificate )) else: license_request.client_id.CopyFrom(self.device.client_id) license_message = SignedMessage() - license_message.type = SignedMessage.MessageType.Value("LICENSE_REQUEST") + license_message.type = SignedMessage.MessageType.LICENSE_REQUEST license_message.msg = license_request.SerializeToString() license_message.signature = pss. \ new(self.device.private_key). \ sign(SHA1.new(license_message.msg)) - self.context[request_id] = self.derive_context(license_message.msg) + session.context[request_id] = self.derive_context(license_message.msg) return license_message.SerializeToString() - def parse_license(self, license_message: Union[bytes, str]) -> list[Key]: + def parse_license(self, session_id: bytes, license_message: Union[SignedMessage, bytes, str]) -> None: + """ + Load Keys from a License Message from a License Server Response. + + Parameters: + session_id: Session identifier. + license_message: A SignedMessage containing a License message. + + Raises: + InvalidSession: If the Session identifier is invalid. + InvalidLicenseMessage: The License message could not be decoded as a Signed + Message or License message. + InvalidContext: If the Session has no Context Data. This is likely to happen + if the License Challenge was not made by this CDM instance, or was not + by this CDM at all. It could also happen if the Session is closed after + calling parse_license but not before it got the context data. + SignatureMismatch: If the Signature of the License SignedMessage does not + match the underlying License. + """ + session = self._sessions.get(session_id) + if not session: + raise InvalidSession(f"Session identifier {session_id!r} is invalid.") + if not license_message: - raise ValueError("Cannot parse an empty license_message as a SignedMessage") + raise InvalidLicenseMessage("Cannot parse an empty license_message") if isinstance(license_message, str): - license_message = base64.b64decode(license_message) + try: + license_message = base64.b64decode(license_message) + except (binascii.Error, binascii.Incomplete) as e: + raise InvalidLicenseMessage(f"Could not decode license_message as Base64, {e}") + if isinstance(license_message, bytes): signed_message = SignedMessage() try: signed_message.ParseFromString(license_message) - except DecodeError: - raise ValueError("Failed to parse license_message as a SignedMessage") + except DecodeError as e: + raise InvalidLicenseMessage(f"Could not parse license_message as a SignedMessage, {e}") license_message = signed_message + if not isinstance(license_message, SignedMessage): - raise ValueError(f"Expecting license_response to be a SignedMessage, got {license_message!r}") + raise InvalidLicenseMessage(f"Expecting license_response to be a SignedMessage, got {license_message!r}") if license_message.type != SignedMessage.MessageType.LICENSE: - raise ValueError( + raise InvalidLicenseMessage( f"Expecting a LICENSE message, not a " f"'{SignedMessage.MessageType.Name(license_message.type)}' message." ) @@ -226,9 +304,9 @@ class Cdm: licence = License() licence.ParseFromString(license_message.msg) - context = self.context.get(licence.id.request_id) + context = session.context.get(licence.id.request_id) if not context: - raise ValueError("Cannot parse a license message without first making a license request") + raise InvalidContext("Cannot parse a license message without first making a license request") session_key = PKCS1_OAEP. \ new(self.device.private_key). \ @@ -242,60 +320,97 @@ class Cdm: digest() if license_message.signature != computed_signature: - raise ValueError("Signature Mismatch on License Message, rejecting license") + raise SignatureMismatch("Signature Mismatch on License Message, rejecting license") - return [ + session.keys = [ Key.from_key_container(key, enc_key) for key in licence.key ] - @staticmethod - def decrypt(content_keys: dict[UUID, str], input_: Path, output: Path, temp: Optional[Path] = None): + def decrypt( + self, + session_id: bytes, + input_file: Union[Path, str], + output_file: Union[Path, str], + temp_dir: Optional[Union[Path, str]] = None, + exists_ok: bool = False + ): """ Decrypt a Widevine-encrypted file using Shaka-packager. Shaka-packager is much more stable than mp4decrypt. + Parameters: + session_id: Session identifier. + input_file: File to be decrypted with Session's currently loaded keys. + output_file: Location to save decrypted file. + temp_dir: Directory to store temporary data while decrypting. + exists_ok: Allow overwriting the output_file if it exists. + Raises: - EnvironmentError if the Shaka Packager executable could not be found. - ValueError if the track has not yet been downloaded. - SubprocessError if Shaka Packager returned a non-zero exit code. + ValueError: If the input or output paths have not been supplied or are + invalid. + FileNotFoundError: If the input file path does not exist. + FileExistsError: If the output file path already exists. Ignored if exists_ok + is set to True. + NoKeysLoaded: No License was parsed for this Session, No Keys available. + EnvironmentError: If the shaka-packager executable could not be found. + subprocess.CalledProcessError: If the shaka-packager call returned a non-zero + exit code. """ - if not content_keys: - raise ValueError("Cannot decrypt without any Content Keys") - if not input_: + if not input_file: raise ValueError("Cannot decrypt nothing, specify an input path") - if not output: + if not output_file: raise ValueError("Cannot decrypt nowhere, specify an output path") + if not isinstance(input_file, (Path, str)): + raise ValueError(f"Expecting input_file to be a Path or str, got {input_file!r}") + if not isinstance(output_file, (Path, str)): + raise ValueError(f"Expecting output_file to be a Path or str, got {output_file!r}") + if not isinstance(temp_dir, (Path, str)) and temp_dir is not None: + raise ValueError(f"Expecting temp_dir to be a Path or str, got {temp_dir!r}") + + input_file = Path(input_file) + output_file = Path(output_file) + if temp_dir: + temp_dir = Path(temp_dir) + + if not input_file.is_file(): + raise FileNotFoundError(f"Input file does not exist, {input_file}") + if output_file.is_file() and not exists_ok: + raise FileExistsError(f"Output file already exists, {output_file}") + + session = self._sessions.get(session_id) + if not session: + raise InvalidSession(f"Session identifier {session_id!r} is invalid.") + + if not session.keys: + raise NoKeysLoaded("No Keys are loaded yet, cannot decrypt") + platform = {"win32": "win", "darwin": "osx"}.get(sys.platform, sys.platform) executable = get_binary_path("shaka-packager", f"packager-{platform}", f"packager-{platform}-x64") if not executable: raise EnvironmentError("Shaka Packager executable not found but is required") args = [ - f"input={input_},stream=0,output={output}", - "--enable_raw_key_decryption", "--keys", - ",".join([ - *[ - f"label={i}:key_id={kid.hex}:key={key.lower()}" - for i, (kid, key) in enumerate(content_keys.items()) - ], - *[ - # Apple TV+ needs this as their files do not use the KID supplied in the manifest - f"label={i}:key_id=00000000000000000000000000000000:key={key.lower()}" - for i, (kid, key) in enumerate(content_keys.items(), len(content_keys)) + f"input={input_file},stream=0,output={output_file}", + "--enable_raw_key_decryption", + "--keys", ",".join([ + label + for i, key in enumerate(session.keys) + for label in [ + f"label=1_{i}:key_id={key.kid.hex}:key={key.key.hex()}", + # some services need the KID blanked, e.g., Apple TV+ + f"label=2_{i}:key_id={'0' * 32}:key={key.key.hex()}" ] + if key.type == "CONTENT" ]) ] - if temp: - temp.mkdir(parents=True, exist_ok=True) - args.extend(["--temp_dir", temp]) + if temp_dir: + temp_dir.mkdir(parents=True, exist_ok=True) + args.extend(["--temp_dir", temp_dir]) - try: - subprocess.check_call([executable, *args]) - except subprocess.CalledProcessError as e: - raise subprocess.SubprocessError(f"Failed to Decrypt! Shaka Packager Error: {e}") + subprocess.check_call([executable, *args]) @staticmethod def encrypt_client_id( @@ -365,7 +480,8 @@ class Cdm: """ def _derive(session_key: bytes, context: bytes, counter: int) -> bytes: - return CMAC.new(session_key, ciphermod=AES). \ + return CMAC. \ + new(session_key, ciphermod=AES). \ update(counter.to_bytes(1, "big") + context). \ digest()