Use free-form strings for stream labels (track types)

Change-Id: I38489acbdfaf4bb491635fdc7f6b0cab77a53574
This commit is contained in:
KongQun Yang 2017-06-13 14:54:12 -07:00
parent 05a5a41969
commit 1aeedd102e
15 changed files with 70 additions and 153 deletions

View File

@ -21,7 +21,8 @@ Status FixedKeySource::FetchKeys(EmeInitDataType init_data_type,
return Status::OK; return Status::OK;
} }
Status FixedKeySource::GetKey(TrackType track_type, EncryptionKey* key) { Status FixedKeySource::GetKey(const std::string& stream_label,
EncryptionKey* key) {
DCHECK(key); DCHECK(key);
DCHECK(encryption_key_); DCHECK(encryption_key_);
*key = *encryption_key_; *key = *encryption_key_;
@ -43,7 +44,7 @@ Status FixedKeySource::GetKey(const std::vector<uint8_t>& key_id,
} }
Status FixedKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index, Status FixedKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) { EncryptionKey* key) {
// Create a copy of the key. // Create a copy of the key.
*key = *encryption_key_; *key = *encryption_key_;

View File

@ -32,11 +32,11 @@ class FixedKeySource : public KeySource {
/// @{ /// @{
Status FetchKeys(EmeInitDataType init_data_type, Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) override; const std::vector<uint8_t>& init_data) override;
Status GetKey(TrackType track_type, EncryptionKey* key) override; Status GetKey(const std::string& stream_label, EncryptionKey* key) override;
Status GetKey(const std::vector<uint8_t>& key_id, Status GetKey(const std::vector<uint8_t>& key_id,
EncryptionKey* key) override; EncryptionKey* key) override;
Status GetCryptoPeriodKey(uint32_t crypto_period_index, Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) override; EncryptionKey* key) override;
/// @} /// @}

View File

@ -53,7 +53,7 @@ TEST(FixedKeySourceTest, CreateFromHexStrings_Succes) {
ASSERT_NE(nullptr, key_source); ASSERT_NE(nullptr, key_source);
EncryptionKey key; EncryptionKey key;
ASSERT_OK(key_source->GetKey(KeySource::TRACK_TYPE_SD, &key)); ASSERT_OK(key_source->GetKey("SomeStreamLabel", &key));
EXPECT_HEX_EQ(kKeyIdHex, key.key_id); EXPECT_HEX_EQ(kKeyIdHex, key.key_id);
EXPECT_HEX_EQ(kKeyHex, key.key); EXPECT_HEX_EQ(kKeyHex, key.key);
@ -70,7 +70,7 @@ TEST(FixedKeySourceTest, CreateFromHexStrings_EmptyPssh) {
ASSERT_NE(nullptr, key_source); ASSERT_NE(nullptr, key_source);
EncryptionKey key; EncryptionKey key;
ASSERT_OK(key_source->GetKey(KeySource::TRACK_TYPE_SD, &key)); ASSERT_OK(key_source->GetKey("SomeStreamLabel", &key));
EXPECT_HEX_EQ(kKeyIdHex, key.key_id); EXPECT_HEX_EQ(kKeyIdHex, key.key_id);
EXPECT_HEX_EQ(kKeyHex, key.key); EXPECT_HEX_EQ(kKeyHex, key.key);

View File

@ -12,47 +12,12 @@ namespace shaka {
namespace media { namespace media {
EncryptionKey::EncryptionKey() {} EncryptionKey::EncryptionKey() {}
EncryptionKey::~EncryptionKey() {} EncryptionKey::~EncryptionKey() {}
KeySource::~KeySource() {}
KeySource::TrackType KeySource::GetTrackTypeFromString(
const std::string& track_type_string) {
if (track_type_string == "SD")
return TRACK_TYPE_SD;
if (track_type_string == "HD")
return TRACK_TYPE_HD;
if (track_type_string == "UHD1")
return TRACK_TYPE_UHD1;
if (track_type_string == "UHD2")
return TRACK_TYPE_UHD2;
if (track_type_string == "AUDIO")
return TRACK_TYPE_AUDIO;
if (track_type_string == "UNSPECIFIED")
return TRACK_TYPE_UNSPECIFIED;
LOG(WARNING) << "Unexpected track type: " << track_type_string;
return TRACK_TYPE_UNKNOWN;
}
std::string KeySource::TrackTypeToString(TrackType track_type) {
switch (track_type) {
case TRACK_TYPE_SD:
return "SD";
case TRACK_TYPE_HD:
return "HD";
case TRACK_TYPE_UHD1:
return "UHD1";
case TRACK_TYPE_UHD2:
return "UHD2";
case TRACK_TYPE_AUDIO:
return "AUDIO";
default:
NOTIMPLEMENTED() << "Unknown track type: " << track_type;
return "UNKNOWN";
}
}
KeySource::KeySource() {} KeySource::KeySource() {}
KeySource::~KeySource() {}
} // namespace media } // namespace media
} // namespace shaka } // namespace shaka

View File

@ -44,17 +44,6 @@ struct EncryptionKey {
/// KeySource is responsible for encryption key acquisition. /// KeySource is responsible for encryption key acquisition.
class KeySource { class KeySource {
public: public:
enum TrackType {
TRACK_TYPE_UNKNOWN = 0,
TRACK_TYPE_SD = 1,
TRACK_TYPE_HD = 2,
TRACK_TYPE_UHD1 = 3,
TRACK_TYPE_UHD2 = 4,
TRACK_TYPE_AUDIO = 5,
TRACK_TYPE_UNSPECIFIED = 6,
NUM_VALID_TRACK_TYPES = 6
};
KeySource(); KeySource();
virtual ~KeySource(); virtual ~KeySource();
@ -65,12 +54,13 @@ class KeySource {
virtual Status FetchKeys(EmeInitDataType init_data_type, virtual Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) = 0; const std::vector<uint8_t>& init_data) = 0;
/// Get encryption key of the specified track type. /// Get encryption key of the specified stream label.
/// @param track_type is the type of track for which retrieving the key. /// @param stream_label is the label of stream for which retrieving the key.
/// @param key is a pointer to the EncryptionKey which will hold the retrieved /// @param key is a pointer to the EncryptionKey which will hold the retrieved
/// key. Owner retains ownership, and may not be NULL. /// key. Owner retains ownership, and may not be NULL.
/// @return OK on success, an error status otherwise. /// @return OK on success, an error status otherwise.
virtual Status GetKey(TrackType track_type, EncryptionKey* key) = 0; virtual Status GetKey(const std::string& stream_label,
EncryptionKey* key) = 0;
/// Get the encryption key specified by the CENC key ID. /// Get the encryption key specified by the CENC key ID.
/// @param key_id is the unique identifier for the key being retreived. /// @param key_id is the unique identifier for the key being retreived.
@ -83,20 +73,14 @@ class KeySource {
/// Get encryption key of the specified track type at the specified index. /// Get encryption key of the specified track type at the specified index.
/// @param crypto_period_index is the sequence number of the key rotation /// @param crypto_period_index is the sequence number of the key rotation
/// period for which the key is being retrieved. /// period for which the key is being retrieved.
/// @param track_type is the type of track for which retrieving the key. /// @param stream_label is the label of stream for which retrieving the key.
/// @param key is a pointer to the EncryptionKey which will hold the retrieved /// @param key is a pointer to the EncryptionKey which will hold the retrieved
/// key. Owner retains ownership, and may not be NULL. /// key. Owner retains ownership, and may not be NULL.
/// @return OK on success, an error status otherwise. /// @return OK on success, an error status otherwise.
virtual Status GetCryptoPeriodKey(uint32_t crypto_period_index, virtual Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) = 0; EncryptionKey* key) = 0;
/// Convert string representation of track type to enum representation.
static TrackType GetTrackTypeFromString(const std::string& track_type_string);
/// Convert TrackType to string.
static std::string TrackTypeToString(TrackType track_type);
private: private:
DISALLOW_COPY_AND_ASSIGN(KeySource); DISALLOW_COPY_AND_ASSIGN(KeySource);
}; };

View File

@ -316,9 +316,10 @@ Status PlayReadyKeySource::FetchKeys(EmeInitDataType init_data_type,
return Status::OK; return Status::OK;
} }
Status PlayReadyKeySource::GetKey(TrackType track_type, EncryptionKey* key) { Status PlayReadyKeySource::GetKey(const std::string& stream_label,
EncryptionKey* key) {
// TODO(robinconnell): Currently all tracks are encrypted using the same // TODO(robinconnell): Currently all tracks are encrypted using the same
// key_id and key. Add the ability to encrypt each track_type using a // key_id and key. Add the ability to encrypt each stream_label using a
// different key_id and key. // different key_id and key.
DCHECK(key); DCHECK(key);
DCHECK(encryption_key_); DCHECK(encryption_key_);
@ -337,7 +338,7 @@ Status PlayReadyKeySource::GetKey(const std::vector<uint8_t>& key_id,
} }
Status PlayReadyKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index, Status PlayReadyKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) { EncryptionKey* key) {
// TODO(robinconnell): Implement key rotation. // TODO(robinconnell): Implement key rotation.
*key = *encryption_key_; *key = *encryption_key_;

View File

@ -42,11 +42,11 @@ class PlayReadyKeySource : public KeySource {
/// @{ /// @{
Status FetchKeys(EmeInitDataType init_data_type, Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) override; const std::vector<uint8_t>& init_data) override;
Status GetKey(TrackType track_type, EncryptionKey* key) override; Status GetKey(const std::string& stream_label, EncryptionKey* key) override;
Status GetKey(const std::vector<uint8_t>& key_id, Status GetKey(const std::vector<uint8_t>& key_id,
EncryptionKey* key) override; EncryptionKey* key) override;
Status GetCryptoPeriodKey(uint32_t crypto_period_index, Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) override; EncryptionKey* key) override;
/// @} /// @}
virtual Status FetchKeysWithProgramIdentifier(const std::string& program_identifier); virtual Status FetchKeysWithProgramIdentifier(const std::string& program_identifier);

View File

@ -237,13 +237,14 @@ Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic); return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
} }
Status WidevineKeySource::GetKey(TrackType track_type, EncryptionKey* key) { Status WidevineKeySource::GetKey(const std::string& stream_label,
EncryptionKey* key) {
DCHECK(key); DCHECK(key);
if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) { if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
return Status(error::INTERNAL_ERROR, return Status(error::INTERNAL_ERROR,
"Cannot find key of type " + TrackTypeToString(track_type)); "Cannot find key for '" + stream_label + "'.");
} }
*key = *encryption_key_map_[track_type]; *key = *encryption_key_map_[stream_label];
return Status::OK; return Status::OK;
} }
@ -261,7 +262,7 @@ Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
} }
Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index, Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) { EncryptionKey* key) {
DCHECK(key_production_thread_.HasBeenStarted()); DCHECK(key_production_thread_.HasBeenStarted());
// TODO(kqyang): This is not elegant. Consider refactoring later. // TODO(kqyang): This is not elegant. Consider refactoring later.
@ -279,7 +280,7 @@ Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
key_production_started_ = true; key_production_started_ = true;
} }
} }
return GetKeyInternal(crypto_period_index, track_type, key); return GetKeyInternal(crypto_period_index, stream_label, key);
} }
void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) { void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
@ -292,12 +293,10 @@ void WidevineKeySource::set_key_fetcher(
} }
Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index, Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) { EncryptionKey* key) {
DCHECK(key_pool_); DCHECK(key_pool_);
DCHECK(key); DCHECK(key);
DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
std::shared_ptr<EncryptionKeyMap> encryption_key_map; std::shared_ptr<EncryptionKeyMap> encryption_key_map;
Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map, Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
@ -310,11 +309,11 @@ Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
return status; return status;
} }
if (encryption_key_map->find(track_type) == encryption_key_map->end()) { if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
return Status(error::INTERNAL_ERROR, return Status(error::INTERNAL_ERROR,
"Cannot find key of type " + TrackTypeToString(track_type)); "Cannot find key for '" + stream_label + "'.");
} }
*key = *encryption_key_map->at(track_type); *key = *encryption_key_map->at(stream_label);
return Status::OK; return Status::OK;
} }
@ -545,11 +544,9 @@ bool WidevineKeySource::ExtractEncryptionKey(
} }
} }
std::string track_type_str; std::string stream_label;
RCHECK(track_dict->GetString("type", &track_type_str)); RCHECK(track_dict->GetString("type", &stream_label));
TrackType track_type = GetTrackTypeFromString(track_type_str); RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey()); std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
@ -573,7 +570,7 @@ bool WidevineKeySource::ExtractEncryptionKey(
encryption_key->key_system_info.push_back(info); encryption_key->key_system_info.push_back(info);
} }
encryption_key_map[track_type] = std::move(encryption_key); encryption_key_map[stream_label] = std::move(encryption_key);
} }
// If the flag exists, create a common system ID PSSH box that contains the // If the flag exists, create a common system ID PSSH box that contains the

View File

@ -39,11 +39,11 @@ class WidevineKeySource : public KeySource {
/// @{ /// @{
Status FetchKeys(EmeInitDataType init_data_type, Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) override; const std::vector<uint8_t>& init_data) override;
Status GetKey(TrackType track_type, EncryptionKey* key) override; Status GetKey(const std::string& stream_label, EncryptionKey* key) override;
Status GetKey(const std::vector<uint8_t>& key_id, Status GetKey(const std::vector<uint8_t>& key_id,
EncryptionKey* key) override; EncryptionKey* key) override;
Status GetCryptoPeriodKey(uint32_t crypto_period_index, Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key) override; EncryptionKey* key) override;
/// @} /// @}
@ -68,13 +68,14 @@ class WidevineKeySource : public KeySource {
void set_key_fetcher(std::unique_ptr<KeyFetcher> key_fetcher); void set_key_fetcher(std::unique_ptr<KeyFetcher> key_fetcher);
private: private:
typedef std::map<TrackType, std::unique_ptr<EncryptionKey>> EncryptionKeyMap; typedef std::map<std::string, std::unique_ptr<EncryptionKey>>
EncryptionKeyMap;
typedef ProducerConsumerQueue<std::shared_ptr<EncryptionKeyMap>> typedef ProducerConsumerQueue<std::shared_ptr<EncryptionKeyMap>>
EncryptionKeyQueue; EncryptionKeyQueue;
// Internal routine for getting keys. // Internal routine for getting keys.
Status GetKeyInternal(uint32_t crypto_period_index, Status GetKeyInternal(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key); EncryptionKey* key);
// The closure task to fetch keys repeatedly. // The closure task to fetch keys repeatedly.

View File

@ -197,17 +197,14 @@ class WidevineKeySourceTest : public Test {
void VerifyKeys(bool classic) { void VerifyKeys(bool classic) {
EncryptionKey encryption_key; EncryptionKey encryption_key;
const std::string kTrackTypes[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"}; const std::string kStreamLabels[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"};
for (size_t i = 0; i < arraysize(kTrackTypes); ++i) { for (const std::string& stream_label : kStreamLabels) {
ASSERT_OK(widevine_key_source_->GetKey( ASSERT_OK(widevine_key_source_->GetKey(stream_label, &encryption_key));
KeySource::GetTrackTypeFromString(kTrackTypes[i]), EXPECT_EQ(GetMockKey(stream_label), ToString(encryption_key.key));
&encryption_key));
EXPECT_EQ(GetMockKey(kTrackTypes[i]), ToString(encryption_key.key));
if (!classic) { if (!classic) {
ASSERT_EQ(add_common_pssh_ ? 2u : 1u, ASSERT_EQ(add_common_pssh_ ? 2u : 1u,
encryption_key.key_system_info.size()); encryption_key.key_system_info.size());
EXPECT_EQ(GetMockKeyId(kTrackTypes[i]), EXPECT_EQ(GetMockKeyId(stream_label), ToString(encryption_key.key_id));
ToString(encryption_key.key_id));
EXPECT_EQ(GetMockPsshData(), EXPECT_EQ(GetMockPsshData(),
ToString(encryption_key.key_system_info[0].pssh_data())); ToString(encryption_key.key_system_info[0].pssh_data()));
@ -220,10 +217,10 @@ class WidevineKeySourceTest : public Test {
const std::vector<std::vector<uint8_t>>& key_ids = const std::vector<std::vector<uint8_t>>& key_ids =
encryption_key.key_system_info[1].key_ids(); encryption_key.key_system_info[1].key_ids();
ASSERT_EQ(arraysize(kTrackTypes), key_ids.size()); ASSERT_EQ(arraysize(kStreamLabels), key_ids.size());
for (size_t j = 0; j < arraysize(kTrackTypes); ++j) { for (const std::string& stream_label : kStreamLabels) {
// Because they are stored in a std::set, the order may change. // Because they are stored in a std::set, the order may change.
const std::string key_id_str = GetMockKeyId(kTrackTypes[j]); const std::string key_id_str = GetMockKeyId(stream_label);
const std::vector<uint8_t> key_id(key_id_str.begin(), const std::vector<uint8_t> key_id(key_id_str.begin(),
key_id_str.end()); key_id_str.end());
EXPECT_THAT(key_ids, testing::Contains(key_id)); EXPECT_THAT(key_ids, testing::Contains(key_id));
@ -243,21 +240,6 @@ class WidevineKeySourceTest : public Test {
DISALLOW_COPY_AND_ASSIGN(WidevineKeySourceTest); DISALLOW_COPY_AND_ASSIGN(WidevineKeySourceTest);
}; };
TEST_F(WidevineKeySourceTest, GetTrackTypeFromString) {
EXPECT_EQ(KeySource::TRACK_TYPE_SD,
KeySource::GetTrackTypeFromString("SD"));
EXPECT_EQ(KeySource::TRACK_TYPE_HD,
KeySource::GetTrackTypeFromString("HD"));
EXPECT_EQ(KeySource::TRACK_TYPE_UHD1,
KeySource::GetTrackTypeFromString("UHD1"));
EXPECT_EQ(KeySource::TRACK_TYPE_UHD2,
KeySource::GetTrackTypeFromString("UHD2"));
EXPECT_EQ(KeySource::TRACK_TYPE_AUDIO,
KeySource::GetTrackTypeFromString("AUDIO"));
EXPECT_EQ(KeySource::TRACK_TYPE_UNKNOWN,
KeySource::GetTrackTypeFromString("FOO"));
}
TEST_F(WidevineKeySourceTest, GenerateSignatureFailure) { TEST_F(WidevineKeySourceTest, GenerateSignatureFailure) {
EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _)) EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _))
.WillOnce(Return(false)); .WillOnce(Return(false));
@ -531,23 +513,19 @@ TEST_P(WidevineKeySourceParameterizedTest, KeyRotationTest) {
ASSERT_OK(widevine_key_source_->FetchKeys(content_id_, kPolicy)); ASSERT_OK(widevine_key_source_->FetchKeys(content_id_, kPolicy));
EncryptionKey encryption_key; EncryptionKey encryption_key;
const std::string kStreamLabels[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"};
for (size_t i = 0; i < arraysize(kCryptoPeriodIndexes); ++i) { for (size_t i = 0; i < arraysize(kCryptoPeriodIndexes); ++i) {
const std::string kTrackTypes[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"}; for (const std::string& stream_label : kStreamLabels) {
for (size_t j = 0; j < 5; ++j) {
ASSERT_OK(widevine_key_source_->GetCryptoPeriodKey( ASSERT_OK(widevine_key_source_->GetCryptoPeriodKey(
kCryptoPeriodIndexes[i], kCryptoPeriodIndexes[i], stream_label, &encryption_key));
KeySource::GetTrackTypeFromString(kTrackTypes[j]), EXPECT_EQ(GetMockKey(stream_label, kCryptoPeriodIndexes[i]),
&encryption_key));
EXPECT_EQ(GetMockKey(kTrackTypes[j], kCryptoPeriodIndexes[i]),
ToString(encryption_key.key)); ToString(encryption_key.key));
} }
} }
// The old crypto period indexes should have been garbage collected. // The old crypto period indexes should have been garbage collected.
Status status = widevine_key_source_->GetCryptoPeriodKey( Status status = widevine_key_source_->GetCryptoPeriodKey(
kFirstCryptoPeriodIndex, kFirstCryptoPeriodIndex, kStreamLabels[0], &encryption_key);
KeySource::TRACK_TYPE_SD,
&encryption_key);
EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code()); EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code());
} }

View File

@ -56,18 +56,7 @@ uint8_t GetNaluLengthSize(const StreamInfo& stream_info) {
return video_stream_info.nalu_length_size(); return video_stream_info.nalu_length_size();
} }
// TODO(kqyang): Update KeySource to accept string base stream label. std::string GetStreamLabelForEncryption(
KeySource::TrackType ToTrackType(const std::string& track_type) {
if (track_type == "SD")
return KeySource::TRACK_TYPE_SD;
if (track_type == "HD")
return KeySource::TRACK_TYPE_HD;
if (track_type == "AUDIO")
return KeySource::TRACK_TYPE_AUDIO;
return KeySource::TRACK_TYPE_SD;
}
KeySource::TrackType GetTrackTypeForEncryption(
const StreamInfo& stream_info, const StreamInfo& stream_info,
const std::function<std::string( const std::function<std::string(
const EncryptionParams::EncryptedStreamAttributes& stream_attributes)>& const EncryptionParams::EncryptedStreamAttributes& stream_attributes)>&
@ -84,7 +73,7 @@ KeySource::TrackType GetTrackTypeForEncryption(
stream_attributes.oneof.video.width = video_stream_info.width(); stream_attributes.oneof.video.width = video_stream_info.width();
stream_attributes.oneof.video.height = video_stream_info.height(); stream_attributes.oneof.video.height = video_stream_info.height();
} }
return ToTrackType(stream_label_func(stream_attributes)); return stream_label_func(stream_attributes);
} }
} // namespace } // namespace
@ -151,7 +140,7 @@ Status EncryptionHandler::ProcessStreamInfo(StreamInfo* stream_info) {
stream_info->time_scale(); stream_info->time_scale();
codec_ = stream_info->codec(); codec_ = stream_info->codec();
nalu_length_size_ = GetNaluLengthSize(*stream_info); nalu_length_size_ = GetNaluLengthSize(*stream_info);
track_type_ = GetTrackTypeForEncryption( stream_label_ = GetStreamLabelForEncryption(
*stream_info, encryption_options_.stream_label_func); *stream_info, encryption_options_.stream_label_func);
switch (codec_) { switch (codec_) {
case kCodecVP9: case kCodecVP9:
@ -195,7 +184,7 @@ Status EncryptionHandler::ProcessStreamInfo(StreamInfo* stream_info) {
// convenience. // convenience.
encryption_key.key = encryption_key.key_id; encryption_key.key = encryption_key.key_id;
} else { } else {
status = key_source_->GetKey(track_type_, &encryption_key); status = key_source_->GetKey(stream_label_, &encryption_key);
if (!status.ok()) if (!status.ok())
return status; return status;
} }
@ -228,7 +217,7 @@ Status EncryptionHandler::ProcessMediaSample(MediaSample* sample) {
if (current_crypto_period_index != prev_crypto_period_index_) { if (current_crypto_period_index != prev_crypto_period_index_) {
EncryptionKey encryption_key; EncryptionKey encryption_key;
Status status = key_source_->GetCryptoPeriodKey( Status status = key_source_->GetCryptoPeriodKey(
current_crypto_period_index, track_type_, &encryption_key); current_crypto_period_index, stream_label_, &encryption_key);
if (!status.ok()) if (!status.ok())
return status; return status;
if (!CreateEncryptor(encryption_key)) if (!CreateEncryptor(encryption_key))

View File

@ -78,7 +78,7 @@ class EncryptionHandler : public MediaHandler {
const EncryptionOptions encryption_options_; const EncryptionOptions encryption_options_;
KeySource* key_source_ = nullptr; KeySource* key_source_ = nullptr;
KeySource::TrackType track_type_ = KeySource::TRACK_TYPE_UNKNOWN; std::string stream_label_;
// Current encryption config and encryptor. // Current encryption config and encryptor.
std::shared_ptr<EncryptionConfig> encryption_config_; std::shared_ptr<EncryptionConfig> encryption_config_;
std::unique_ptr<AesCryptor> encryptor_; std::unique_ptr<AesCryptor> encryptor_;

View File

@ -56,10 +56,11 @@ const uint8_t kKeyRotationDefaultKeyId[] = {
class MockKeySource : public FixedKeySource { class MockKeySource : public FixedKeySource {
public: public:
MOCK_METHOD2(GetKey, Status(TrackType track_type, EncryptionKey* key)); MOCK_METHOD2(GetKey,
Status(const std::string& stream_label, EncryptionKey* key));
MOCK_METHOD3(GetCryptoPeriodKey, MOCK_METHOD3(GetCryptoPeriodKey,
Status(uint32_t crypto_period_index, Status(uint32_t crypto_period_index,
TrackType track_type, const std::string& stream_label,
EncryptionKey* key)); EncryptionKey* key));
}; };
@ -657,7 +658,7 @@ TEST_F(EncryptionHandlerTrackTypeTest, AudioTrackType) {
return kAudioStreamLabel; return kAudioStreamLabel;
}; };
SetUpEncryptionHandler(encryption_options); SetUpEncryptionHandler(encryption_options);
EXPECT_CALL(mock_key_source_, GetKey(KeySource::TRACK_TYPE_AUDIO, _)) EXPECT_CALL(mock_key_source_, GetKey(kAudioStreamLabel, _))
.WillOnce( .WillOnce(
DoAll(SetArgPointee<1>(GetMockEncryptionKey()), Return(Status::OK))); DoAll(SetArgPointee<1>(GetMockEncryptionKey()), Return(Status::OK)));
ASSERT_OK(Process(GetAudioStreamInfoStreamData(kStreamIndex, kTimeScale))); ASSERT_OK(Process(GetAudioStreamInfoStreamData(kStreamIndex, kTimeScale)));
@ -676,7 +677,7 @@ TEST_F(EncryptionHandlerTrackTypeTest, VideoTrackType) {
return kSdVideoStreamLabel; return kSdVideoStreamLabel;
}; };
SetUpEncryptionHandler(encryption_options); SetUpEncryptionHandler(encryption_options);
EXPECT_CALL(mock_key_source_, GetKey(KeySource::TRACK_TYPE_SD, _)) EXPECT_CALL(mock_key_source_, GetKey(kSdVideoStreamLabel, _))
.WillOnce( .WillOnce(
DoAll(SetArgPointee<1>(GetMockEncryptionKey()), Return(Status::OK))); DoAll(SetArgPointee<1>(GetMockEncryptionKey()), Return(Status::OK)));
std::unique_ptr<StreamData> stream_data = std::unique_ptr<StreamData> stream_data =

View File

@ -1076,8 +1076,8 @@ bool WvmMediaParser::GetAssetKey(const uint8_t* asset_id,
return false; return false;
} }
status = decryption_key_source_->GetKey(KeySource::TRACK_TYPE_HD, const char kHdStreamLabel[] = "HD";
encryption_key); status = decryption_key_source_->GetKey(kHdStreamLabel, encryption_key);
if (!status.ok()) { if (!status.ok()) {
LOG(ERROR) << "Fetch Key(s) failed for AssetID = " LOG(ERROR) << "Fetch Key(s) failed for AssetID = "
<< ntohlFromBuffer(asset_id) << ", error = " << status; << ntohlFromBuffer(asset_id) << ", error = " << status;

View File

@ -55,8 +55,8 @@ class MockKeySource : public FixedKeySource {
MOCK_METHOD2(FetchKeys, MOCK_METHOD2(FetchKeys,
Status(EmeInitDataType init_data_type, Status(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data)); const std::vector<uint8_t>& init_data));
MOCK_METHOD2(GetKey, Status(TrackType track_type, MOCK_METHOD2(GetKey,
EncryptionKey* key)); Status(const std::string& stream_label, EncryptionKey* key));
private: private:
DISALLOW_COPY_AND_ASSIGN(MockKeySource); DISALLOW_COPY_AND_ASSIGN(MockKeySource);