diff --git a/pywidevine/pssh.py b/pywidevine/pssh.py index 29e651e..2eae1fe 100644 --- a/pywidevine/pssh.py +++ b/pywidevine/pssh.py @@ -170,25 +170,6 @@ class PSSH: if init_data is None and key_ids is None: raise ValueError("Version 1 PSSH boxes must use either init_data or key_ids but neither were provided") - if key_ids is not None: - # ensure key_ids are UUID, supports hex, base64, and bytes - if not all(isinstance(x, (UUID, bytes, str)) for x in key_ids): - not_bytes = [x for x in key_ids if not isinstance(x, (UUID, bytes, str))] - raise TypeError( - "Expected all of key_ids to be a UUID, hex, base64, or bytes, but one or more are not, " - f"{not_bytes!r}" - ) - key_ids = [ - UUID(bytes=key_id_b) - for key_id in key_ids - for key_id_b in [ - key_id.bytes if isinstance(key_id, UUID) else - bytes.fromhex(key_id) if all(c in string.hexdigits for c in key_id) else - base64.b64decode(key_id) if isinstance(key_id, str) else - key_id - ] - ] - if init_data is not None: if isinstance(init_data, WidevinePsshData): init_data = init_data.SerializeToString() @@ -217,7 +198,7 @@ class PSSH: # The version must be reinforced ONLY if we have key_id data or there's a possibility of making # a v1 PSSH box, that did not have key_IDs set in the PSSH box. pssh.version = version - pssh.set_key_ids(key_ids) + pssh.set_key_ids(cls.parse_key_ids(key_ids)) return pssh @@ -430,3 +411,32 @@ class PSSH: ] self.init_data = cenc_header.SerializeToString() + + @staticmethod + def parse_key_ids(key_ids: list[Union[UUID, str, bytes]]) -> list[UUID]: + """ + Parse a list of Key IDs in hex, base64, or bytes to UUIDs. + + Raises TypeError if `key_ids` is not a list, or the list contains one + or more items that are not a UUID, str, or bytes object. + """ + if not isinstance(key_ids, list): + raise TypeError(f"Expected key_ids to be a list, not {key_ids!r}") + + if not all(isinstance(x, (UUID, str, bytes)) for x in key_ids): + raise TypeError("Some items of key_ids are not a UUID, str, or bytes. Unsure how to continue...") + + uuids = [ + UUID(bytes=key_id_b) + for key_id in key_ids + for key_id_b in [ + key_id.bytes if isinstance(key_id, UUID) else + ( + bytes.fromhex(key_id) if all(c in string.hexdigits for c in key_id) else + base64.b64decode(key_id) + ) if isinstance(key_id, str) else + key_id + ] + ] + + return uuids