Use free-form strings for stream labels (track types)

Change-Id: I38489acbdfaf4bb491635fdc7f6b0cab77a53574
This commit is contained in:
KongQun Yang 2017-06-13 14:54:12 -07:00
parent 05a5a41969
commit 1aeedd102e
15 changed files with 70 additions and 153 deletions

View File

@ -21,7 +21,8 @@ Status FixedKeySource::FetchKeys(EmeInitDataType init_data_type,
return Status::OK;
}
Status FixedKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
Status FixedKeySource::GetKey(const std::string& stream_label,
EncryptionKey* key) {
DCHECK(key);
DCHECK(encryption_key_);
*key = *encryption_key_;
@ -43,7 +44,7 @@ Status FixedKeySource::GetKey(const std::vector<uint8_t>& key_id,
}
Status FixedKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) {
// Create a copy of the key.
*key = *encryption_key_;

View File

@ -32,11 +32,11 @@ class FixedKeySource : public KeySource {
/// @{
Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) override;
Status GetKey(TrackType track_type, EncryptionKey* key) override;
Status GetKey(const std::string& stream_label, EncryptionKey* key) override;
Status GetKey(const std::vector<uint8_t>& key_id,
EncryptionKey* key) override;
Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) override;
/// @}

View File

@ -53,7 +53,7 @@ TEST(FixedKeySourceTest, CreateFromHexStrings_Succes) {
ASSERT_NE(nullptr, key_source);
EncryptionKey key;
ASSERT_OK(key_source->GetKey(KeySource::TRACK_TYPE_SD, &key));
ASSERT_OK(key_source->GetKey("SomeStreamLabel", &key));
EXPECT_HEX_EQ(kKeyIdHex, key.key_id);
EXPECT_HEX_EQ(kKeyHex, key.key);
@ -70,7 +70,7 @@ TEST(FixedKeySourceTest, CreateFromHexStrings_EmptyPssh) {
ASSERT_NE(nullptr, key_source);
EncryptionKey key;
ASSERT_OK(key_source->GetKey(KeySource::TRACK_TYPE_SD, &key));
ASSERT_OK(key_source->GetKey("SomeStreamLabel", &key));
EXPECT_HEX_EQ(kKeyIdHex, key.key_id);
EXPECT_HEX_EQ(kKeyHex, key.key);

View File

@ -12,47 +12,12 @@ namespace shaka {
namespace media {
EncryptionKey::EncryptionKey() {}
EncryptionKey::~EncryptionKey() {}
KeySource::~KeySource() {}
KeySource::TrackType KeySource::GetTrackTypeFromString(
const std::string& track_type_string) {
if (track_type_string == "SD")
return TRACK_TYPE_SD;
if (track_type_string == "HD")
return TRACK_TYPE_HD;
if (track_type_string == "UHD1")
return TRACK_TYPE_UHD1;
if (track_type_string == "UHD2")
return TRACK_TYPE_UHD2;
if (track_type_string == "AUDIO")
return TRACK_TYPE_AUDIO;
if (track_type_string == "UNSPECIFIED")
return TRACK_TYPE_UNSPECIFIED;
LOG(WARNING) << "Unexpected track type: " << track_type_string;
return TRACK_TYPE_UNKNOWN;
}
std::string KeySource::TrackTypeToString(TrackType track_type) {
switch (track_type) {
case TRACK_TYPE_SD:
return "SD";
case TRACK_TYPE_HD:
return "HD";
case TRACK_TYPE_UHD1:
return "UHD1";
case TRACK_TYPE_UHD2:
return "UHD2";
case TRACK_TYPE_AUDIO:
return "AUDIO";
default:
NOTIMPLEMENTED() << "Unknown track type: " << track_type;
return "UNKNOWN";
}
}
KeySource::KeySource() {}
KeySource::~KeySource() {}
} // namespace media
} // namespace shaka

View File

@ -44,17 +44,6 @@ struct EncryptionKey {
/// KeySource is responsible for encryption key acquisition.
class KeySource {
public:
enum TrackType {
TRACK_TYPE_UNKNOWN = 0,
TRACK_TYPE_SD = 1,
TRACK_TYPE_HD = 2,
TRACK_TYPE_UHD1 = 3,
TRACK_TYPE_UHD2 = 4,
TRACK_TYPE_AUDIO = 5,
TRACK_TYPE_UNSPECIFIED = 6,
NUM_VALID_TRACK_TYPES = 6
};
KeySource();
virtual ~KeySource();
@ -65,12 +54,13 @@ class KeySource {
virtual Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) = 0;
/// Get encryption key of the specified track type.
/// @param track_type is the type of track for which retrieving the key.
/// Get encryption key of the specified stream label.
/// @param stream_label is the label of stream for which retrieving the key.
/// @param key is a pointer to the EncryptionKey which will hold the retrieved
/// key. Owner retains ownership, and may not be NULL.
/// @return OK on success, an error status otherwise.
virtual Status GetKey(TrackType track_type, EncryptionKey* key) = 0;
virtual Status GetKey(const std::string& stream_label,
EncryptionKey* key) = 0;
/// Get the encryption key specified by the CENC key ID.
/// @param key_id is the unique identifier for the key being retreived.
@ -83,20 +73,14 @@ class KeySource {
/// Get encryption key of the specified track type at the specified index.
/// @param crypto_period_index is the sequence number of the key rotation
/// period for which the key is being retrieved.
/// @param track_type is the type of track for which retrieving the key.
/// @param stream_label is the label of stream for which retrieving the key.
/// @param key is a pointer to the EncryptionKey which will hold the retrieved
/// key. Owner retains ownership, and may not be NULL.
/// @return OK on success, an error status otherwise.
virtual Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) = 0;
/// Convert string representation of track type to enum representation.
static TrackType GetTrackTypeFromString(const std::string& track_type_string);
/// Convert TrackType to string.
static std::string TrackTypeToString(TrackType track_type);
private:
DISALLOW_COPY_AND_ASSIGN(KeySource);
};

View File

@ -316,9 +316,10 @@ Status PlayReadyKeySource::FetchKeys(EmeInitDataType init_data_type,
return Status::OK;
}
Status PlayReadyKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
Status PlayReadyKeySource::GetKey(const std::string& stream_label,
EncryptionKey* key) {
// TODO(robinconnell): Currently all tracks are encrypted using the same
// key_id and key. Add the ability to encrypt each track_type using a
// key_id and key. Add the ability to encrypt each stream_label using a
// different key_id and key.
DCHECK(key);
DCHECK(encryption_key_);
@ -337,7 +338,7 @@ Status PlayReadyKeySource::GetKey(const std::vector<uint8_t>& key_id,
}
Status PlayReadyKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) {
// TODO(robinconnell): Implement key rotation.
*key = *encryption_key_;

View File

@ -42,11 +42,11 @@ class PlayReadyKeySource : public KeySource {
/// @{
Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) override;
Status GetKey(TrackType track_type, EncryptionKey* key) override;
Status GetKey(const std::string& stream_label, EncryptionKey* key) override;
Status GetKey(const std::vector<uint8_t>& key_id,
EncryptionKey* key) override;
Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) override;
/// @}
virtual Status FetchKeysWithProgramIdentifier(const std::string& program_identifier);

View File

@ -237,13 +237,14 @@ Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
}
Status WidevineKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
Status WidevineKeySource::GetKey(const std::string& stream_label,
EncryptionKey* key) {
DCHECK(key);
if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
return Status(error::INTERNAL_ERROR,
"Cannot find key of type " + TrackTypeToString(track_type));
"Cannot find key for '" + stream_label + "'.");
}
*key = *encryption_key_map_[track_type];
*key = *encryption_key_map_[stream_label];
return Status::OK;
}
@ -261,7 +262,7 @@ Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
}
Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) {
DCHECK(key_production_thread_.HasBeenStarted());
// TODO(kqyang): This is not elegant. Consider refactoring later.
@ -279,7 +280,7 @@ Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
key_production_started_ = true;
}
}
return GetKeyInternal(crypto_period_index, track_type, key);
return GetKeyInternal(crypto_period_index, stream_label, key);
}
void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
@ -292,12 +293,10 @@ void WidevineKeySource::set_key_fetcher(
}
Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) {
DCHECK(key_pool_);
DCHECK(key);
DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
std::shared_ptr<EncryptionKeyMap> encryption_key_map;
Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
@ -310,11 +309,11 @@ Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
return status;
}
if (encryption_key_map->find(track_type) == encryption_key_map->end()) {
if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
return Status(error::INTERNAL_ERROR,
"Cannot find key of type " + TrackTypeToString(track_type));
"Cannot find key for '" + stream_label + "'.");
}
*key = *encryption_key_map->at(track_type);
*key = *encryption_key_map->at(stream_label);
return Status::OK;
}
@ -545,11 +544,9 @@ bool WidevineKeySource::ExtractEncryptionKey(
}
}
std::string track_type_str;
RCHECK(track_dict->GetString("type", &track_type_str));
TrackType track_type = GetTrackTypeFromString(track_type_str);
DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
std::string stream_label;
RCHECK(track_dict->GetString("type", &stream_label));
RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
@ -573,7 +570,7 @@ bool WidevineKeySource::ExtractEncryptionKey(
encryption_key->key_system_info.push_back(info);
}
encryption_key_map[track_type] = std::move(encryption_key);
encryption_key_map[stream_label] = std::move(encryption_key);
}
// If the flag exists, create a common system ID PSSH box that contains the

View File

@ -39,11 +39,11 @@ class WidevineKeySource : public KeySource {
/// @{
Status FetchKeys(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data) override;
Status GetKey(TrackType track_type, EncryptionKey* key) override;
Status GetKey(const std::string& stream_label, EncryptionKey* key) override;
Status GetKey(const std::vector<uint8_t>& key_id,
EncryptionKey* key) override;
Status GetCryptoPeriodKey(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key) override;
/// @}
@ -68,13 +68,14 @@ class WidevineKeySource : public KeySource {
void set_key_fetcher(std::unique_ptr<KeyFetcher> key_fetcher);
private:
typedef std::map<TrackType, std::unique_ptr<EncryptionKey>> EncryptionKeyMap;
typedef std::map<std::string, std::unique_ptr<EncryptionKey>>
EncryptionKeyMap;
typedef ProducerConsumerQueue<std::shared_ptr<EncryptionKeyMap>>
EncryptionKeyQueue;
// Internal routine for getting keys.
Status GetKeyInternal(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key);
// The closure task to fetch keys repeatedly.

View File

@ -197,17 +197,14 @@ class WidevineKeySourceTest : public Test {
void VerifyKeys(bool classic) {
EncryptionKey encryption_key;
const std::string kTrackTypes[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"};
for (size_t i = 0; i < arraysize(kTrackTypes); ++i) {
ASSERT_OK(widevine_key_source_->GetKey(
KeySource::GetTrackTypeFromString(kTrackTypes[i]),
&encryption_key));
EXPECT_EQ(GetMockKey(kTrackTypes[i]), ToString(encryption_key.key));
const std::string kStreamLabels[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"};
for (const std::string& stream_label : kStreamLabels) {
ASSERT_OK(widevine_key_source_->GetKey(stream_label, &encryption_key));
EXPECT_EQ(GetMockKey(stream_label), ToString(encryption_key.key));
if (!classic) {
ASSERT_EQ(add_common_pssh_ ? 2u : 1u,
encryption_key.key_system_info.size());
EXPECT_EQ(GetMockKeyId(kTrackTypes[i]),
ToString(encryption_key.key_id));
EXPECT_EQ(GetMockKeyId(stream_label), ToString(encryption_key.key_id));
EXPECT_EQ(GetMockPsshData(),
ToString(encryption_key.key_system_info[0].pssh_data()));
@ -220,10 +217,10 @@ class WidevineKeySourceTest : public Test {
const std::vector<std::vector<uint8_t>>& key_ids =
encryption_key.key_system_info[1].key_ids();
ASSERT_EQ(arraysize(kTrackTypes), key_ids.size());
for (size_t j = 0; j < arraysize(kTrackTypes); ++j) {
ASSERT_EQ(arraysize(kStreamLabels), key_ids.size());
for (const std::string& stream_label : kStreamLabels) {
// Because they are stored in a std::set, the order may change.
const std::string key_id_str = GetMockKeyId(kTrackTypes[j]);
const std::string key_id_str = GetMockKeyId(stream_label);
const std::vector<uint8_t> key_id(key_id_str.begin(),
key_id_str.end());
EXPECT_THAT(key_ids, testing::Contains(key_id));
@ -243,21 +240,6 @@ class WidevineKeySourceTest : public Test {
DISALLOW_COPY_AND_ASSIGN(WidevineKeySourceTest);
};
TEST_F(WidevineKeySourceTest, GetTrackTypeFromString) {
EXPECT_EQ(KeySource::TRACK_TYPE_SD,
KeySource::GetTrackTypeFromString("SD"));
EXPECT_EQ(KeySource::TRACK_TYPE_HD,
KeySource::GetTrackTypeFromString("HD"));
EXPECT_EQ(KeySource::TRACK_TYPE_UHD1,
KeySource::GetTrackTypeFromString("UHD1"));
EXPECT_EQ(KeySource::TRACK_TYPE_UHD2,
KeySource::GetTrackTypeFromString("UHD2"));
EXPECT_EQ(KeySource::TRACK_TYPE_AUDIO,
KeySource::GetTrackTypeFromString("AUDIO"));
EXPECT_EQ(KeySource::TRACK_TYPE_UNKNOWN,
KeySource::GetTrackTypeFromString("FOO"));
}
TEST_F(WidevineKeySourceTest, GenerateSignatureFailure) {
EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _))
.WillOnce(Return(false));
@ -531,23 +513,19 @@ TEST_P(WidevineKeySourceParameterizedTest, KeyRotationTest) {
ASSERT_OK(widevine_key_source_->FetchKeys(content_id_, kPolicy));
EncryptionKey encryption_key;
const std::string kStreamLabels[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"};
for (size_t i = 0; i < arraysize(kCryptoPeriodIndexes); ++i) {
const std::string kTrackTypes[] = {"SD", "HD", "UHD1", "UHD2", "AUDIO"};
for (size_t j = 0; j < 5; ++j) {
for (const std::string& stream_label : kStreamLabels) {
ASSERT_OK(widevine_key_source_->GetCryptoPeriodKey(
kCryptoPeriodIndexes[i],
KeySource::GetTrackTypeFromString(kTrackTypes[j]),
&encryption_key));
EXPECT_EQ(GetMockKey(kTrackTypes[j], kCryptoPeriodIndexes[i]),
kCryptoPeriodIndexes[i], stream_label, &encryption_key));
EXPECT_EQ(GetMockKey(stream_label, kCryptoPeriodIndexes[i]),
ToString(encryption_key.key));
}
}
// The old crypto period indexes should have been garbage collected.
Status status = widevine_key_source_->GetCryptoPeriodKey(
kFirstCryptoPeriodIndex,
KeySource::TRACK_TYPE_SD,
&encryption_key);
kFirstCryptoPeriodIndex, kStreamLabels[0], &encryption_key);
EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code());
}

View File

@ -56,18 +56,7 @@ uint8_t GetNaluLengthSize(const StreamInfo& stream_info) {
return video_stream_info.nalu_length_size();
}
// TODO(kqyang): Update KeySource to accept string base stream label.
KeySource::TrackType ToTrackType(const std::string& track_type) {
if (track_type == "SD")
return KeySource::TRACK_TYPE_SD;
if (track_type == "HD")
return KeySource::TRACK_TYPE_HD;
if (track_type == "AUDIO")
return KeySource::TRACK_TYPE_AUDIO;
return KeySource::TRACK_TYPE_SD;
}
KeySource::TrackType GetTrackTypeForEncryption(
std::string GetStreamLabelForEncryption(
const StreamInfo& stream_info,
const std::function<std::string(
const EncryptionParams::EncryptedStreamAttributes& stream_attributes)>&
@ -84,7 +73,7 @@ KeySource::TrackType GetTrackTypeForEncryption(
stream_attributes.oneof.video.width = video_stream_info.width();
stream_attributes.oneof.video.height = video_stream_info.height();
}
return ToTrackType(stream_label_func(stream_attributes));
return stream_label_func(stream_attributes);
}
} // namespace
@ -151,7 +140,7 @@ Status EncryptionHandler::ProcessStreamInfo(StreamInfo* stream_info) {
stream_info->time_scale();
codec_ = stream_info->codec();
nalu_length_size_ = GetNaluLengthSize(*stream_info);
track_type_ = GetTrackTypeForEncryption(
stream_label_ = GetStreamLabelForEncryption(
*stream_info, encryption_options_.stream_label_func);
switch (codec_) {
case kCodecVP9:
@ -195,7 +184,7 @@ Status EncryptionHandler::ProcessStreamInfo(StreamInfo* stream_info) {
// convenience.
encryption_key.key = encryption_key.key_id;
} else {
status = key_source_->GetKey(track_type_, &encryption_key);
status = key_source_->GetKey(stream_label_, &encryption_key);
if (!status.ok())
return status;
}
@ -228,7 +217,7 @@ Status EncryptionHandler::ProcessMediaSample(MediaSample* sample) {
if (current_crypto_period_index != prev_crypto_period_index_) {
EncryptionKey encryption_key;
Status status = key_source_->GetCryptoPeriodKey(
current_crypto_period_index, track_type_, &encryption_key);
current_crypto_period_index, stream_label_, &encryption_key);
if (!status.ok())
return status;
if (!CreateEncryptor(encryption_key))

View File

@ -78,7 +78,7 @@ class EncryptionHandler : public MediaHandler {
const EncryptionOptions encryption_options_;
KeySource* key_source_ = nullptr;
KeySource::TrackType track_type_ = KeySource::TRACK_TYPE_UNKNOWN;
std::string stream_label_;
// Current encryption config and encryptor.
std::shared_ptr<EncryptionConfig> encryption_config_;
std::unique_ptr<AesCryptor> encryptor_;

View File

@ -56,10 +56,11 @@ const uint8_t kKeyRotationDefaultKeyId[] = {
class MockKeySource : public FixedKeySource {
public:
MOCK_METHOD2(GetKey, Status(TrackType track_type, EncryptionKey* key));
MOCK_METHOD2(GetKey,
Status(const std::string& stream_label, EncryptionKey* key));
MOCK_METHOD3(GetCryptoPeriodKey,
Status(uint32_t crypto_period_index,
TrackType track_type,
const std::string& stream_label,
EncryptionKey* key));
};
@ -657,7 +658,7 @@ TEST_F(EncryptionHandlerTrackTypeTest, AudioTrackType) {
return kAudioStreamLabel;
};
SetUpEncryptionHandler(encryption_options);
EXPECT_CALL(mock_key_source_, GetKey(KeySource::TRACK_TYPE_AUDIO, _))
EXPECT_CALL(mock_key_source_, GetKey(kAudioStreamLabel, _))
.WillOnce(
DoAll(SetArgPointee<1>(GetMockEncryptionKey()), Return(Status::OK)));
ASSERT_OK(Process(GetAudioStreamInfoStreamData(kStreamIndex, kTimeScale)));
@ -676,7 +677,7 @@ TEST_F(EncryptionHandlerTrackTypeTest, VideoTrackType) {
return kSdVideoStreamLabel;
};
SetUpEncryptionHandler(encryption_options);
EXPECT_CALL(mock_key_source_, GetKey(KeySource::TRACK_TYPE_SD, _))
EXPECT_CALL(mock_key_source_, GetKey(kSdVideoStreamLabel, _))
.WillOnce(
DoAll(SetArgPointee<1>(GetMockEncryptionKey()), Return(Status::OK)));
std::unique_ptr<StreamData> stream_data =

View File

@ -1076,8 +1076,8 @@ bool WvmMediaParser::GetAssetKey(const uint8_t* asset_id,
return false;
}
status = decryption_key_source_->GetKey(KeySource::TRACK_TYPE_HD,
encryption_key);
const char kHdStreamLabel[] = "HD";
status = decryption_key_source_->GetKey(kHdStreamLabel, encryption_key);
if (!status.ok()) {
LOG(ERROR) << "Fetch Key(s) failed for AssetID = "
<< ntohlFromBuffer(asset_id) << ", error = " << status;

View File

@ -55,8 +55,8 @@ class MockKeySource : public FixedKeySource {
MOCK_METHOD2(FetchKeys,
Status(EmeInitDataType init_data_type,
const std::vector<uint8_t>& init_data));
MOCK_METHOD2(GetKey, Status(TrackType track_type,
EncryptionKey* key));
MOCK_METHOD2(GetKey,
Status(const std::string& stream_label, EncryptionKey* key));
private:
DISALLOW_COPY_AND_ASSIGN(MockKeySource);