Have set_key_ids method call parse_key_ids directly
This improves user-experience by allowing set_key_ids to accept more types of Key ID formats directly. This also reduces code duplication because the parse function also checks the validity of the Key IDs list for set_key_ids.
This commit is contained in:
parent
52fd5e74ba
commit
cd990e0f4e
|
@ -198,7 +198,7 @@ class PSSH:
|
||||||
# The version must be reinforced ONLY if we have key_id data or there's a possibility of making
|
# The version must be reinforced ONLY if we have key_id data or there's a possibility of making
|
||||||
# a v1 PSSH box, that did not have key_IDs set in the PSSH box.
|
# a v1 PSSH box, that did not have key_IDs set in the PSSH box.
|
||||||
pssh.version = version
|
pssh.version = version
|
||||||
pssh.set_key_ids(cls.parse_key_ids(key_ids))
|
pssh.set_key_ids(key_ids)
|
||||||
|
|
||||||
return pssh
|
return pssh
|
||||||
|
|
||||||
|
@ -384,30 +384,25 @@ class PSSH:
|
||||||
self.init_data = pro
|
self.init_data = pro
|
||||||
self.system_id = PSSH.SystemId.PlayReady
|
self.system_id = PSSH.SystemId.PlayReady
|
||||||
|
|
||||||
def set_key_ids(self, key_ids: list[UUID]) -> None:
|
def set_key_ids(self, key_ids: list[Union[UUID, str, bytes]]) -> None:
|
||||||
"""Overwrite all Key IDs with the specified Key IDs."""
|
"""Overwrite all Key IDs with the specified Key IDs."""
|
||||||
if self.system_id != PSSH.SystemId.Widevine:
|
if self.system_id != PSSH.SystemId.Widevine:
|
||||||
# TODO: Add support for setting the Key IDs in a PlayReady Header
|
# TODO: Add support for setting the Key IDs in a PlayReady Header
|
||||||
raise ValueError(f"Only Widevine PSSH Boxes are supported, not {self.system_id}.")
|
raise ValueError(f"Only Widevine PSSH Boxes are supported, not {self.system_id}.")
|
||||||
|
|
||||||
if not isinstance(key_ids, list):
|
key_id_uuids = self.parse_key_ids(key_ids)
|
||||||
raise TypeError(f"Expecting key_ids to be a list, not {key_ids!r}")
|
|
||||||
|
|
||||||
if not all(isinstance(x, UUID) for x in key_ids):
|
|
||||||
not_uuid = [x for x in key_ids if not isinstance(x, UUID)]
|
|
||||||
raise TypeError(f"All Key IDs in key_ids must be a {UUID}, not {not_uuid}")
|
|
||||||
|
|
||||||
if self.version == 1 or self.__key_ids:
|
if self.version == 1 or self.__key_ids:
|
||||||
# only use v1 box key_ids if version is 1, or it's already being used
|
# only use v1 box key_ids if version is 1, or it's already being used
|
||||||
# this is in case the service stupidly expects it for version 0
|
# this is in case the service stupidly expects it for version 0
|
||||||
self.__key_ids = key_ids
|
self.__key_ids = key_id_uuids
|
||||||
|
|
||||||
cenc_header = WidevinePsshData()
|
cenc_header = WidevinePsshData()
|
||||||
cenc_header.ParseFromString(self.init_data)
|
cenc_header.ParseFromString(self.init_data)
|
||||||
|
|
||||||
cenc_header.key_ids[:] = [
|
cenc_header.key_ids[:] = [
|
||||||
key_id.bytes
|
key_id.bytes
|
||||||
for key_id in key_ids
|
for key_id in key_id_uuids
|
||||||
]
|
]
|
||||||
|
|
||||||
self.init_data = cenc_header.SerializeToString()
|
self.init_data = cenc_header.SerializeToString()
|
||||||
|
|
Loading…
Reference in New Issue