From 1f315ba921533eb9ee59160f36ffc13603e8eb30 Mon Sep 17 00:00:00 2001 From: Kongqun Yang Date: Thu, 24 Apr 2014 09:59:07 -0700 Subject: [PATCH] Support key rotation in widevine encryption key source Change-Id: I05ded15fa666119c86a1d3f1c99123b9cda60b49 --- app/packager_main.cc | 5 +- media/base/aes_encryptor.cc | 3 +- media/base/encryption_key_source.cc | 2 +- media/base/encryption_key_source.h | 2 +- media/base/producer_consumer_queue.h | 54 ++++-- .../base/producer_consumer_queue_unittest.cc | 2 - media/base/widevine_encryption_key_source.cc | 176 +++++++++++++++--- media/base/widevine_encryption_key_source.h | 47 ++++- ...widevine_encryption_key_source_unittest.cc | 157 ++++++++++++++-- 9 files changed, 373 insertions(+), 75 deletions(-) diff --git a/app/packager_main.cc b/app/packager_main.cc index 1c24408f30..6da383466f 100644 --- a/app/packager_main.cc +++ b/app/packager_main.cc @@ -76,7 +76,10 @@ scoped_ptr CreateEncryptionKeySource() { } encryption_key_source.reset(new WidevineEncryptionKeySource( - FLAGS_server_url, FLAGS_content_id, signer.Pass())); + FLAGS_server_url, + FLAGS_content_id, + signer.Pass(), + FLAGS_crypto_period_duration == 0 ? kDisableKeyRotation : 0)); } else if (FLAGS_enable_fixed_key_encryption) { encryption_key_source = EncryptionKeySource::CreateFromHexStrings( FLAGS_key_id, FLAGS_key, FLAGS_pssh, ""); diff --git a/media/base/aes_encryptor.cc b/media/base/aes_encryptor.cc index 9ab61c2fee..32f1cda75e 100644 --- a/media/base/aes_encryptor.cc +++ b/media/base/aes_encryptor.cc @@ -169,11 +169,12 @@ void AesCbcEncryptor::Encrypt(const std::string& plaintext, padded_text.append(num_padding_bytes, static_cast(num_padding_bytes)); ciphertext->resize(padded_text.size()); + std::vector iv(iv_); AES_cbc_encrypt(reinterpret_cast(padded_text.data()), reinterpret_cast(string_as_array(ciphertext)), padded_text.size(), encrypt_key_.get(), - &iv_[0], + &iv[0], AES_ENCRYPT); } diff --git a/media/base/encryption_key_source.cc b/media/base/encryption_key_source.cc index 4404fc8df0..d9dd33cd12 100644 --- a/media/base/encryption_key_source.cc +++ b/media/base/encryption_key_source.cc @@ -30,7 +30,7 @@ Status EncryptionKeySource::GetKey(TrackType track_type, EncryptionKey* key) { return Status::OK; } -Status EncryptionKeySource::GetCryptoPeriodKey(size_t crypto_period_index, +Status EncryptionKeySource::GetCryptoPeriodKey(uint32 crypto_period_index, TrackType track_type, EncryptionKey* key) { NOTIMPLEMENTED(); diff --git a/media/base/encryption_key_source.h b/media/base/encryption_key_source.h index 79486f0bc8..b41870edc9 100644 --- a/media/base/encryption_key_source.h +++ b/media/base/encryption_key_source.h @@ -43,7 +43,7 @@ class EncryptionKeySource { /// Get encryption key of the specified track type at the specified index. /// @return OK on success, an error status otherwise. - virtual Status GetCryptoPeriodKey(size_t crypto_period_index, + virtual Status GetCryptoPeriodKey(uint32 crypto_period_index, TrackType track_type, EncryptionKey* key); diff --git a/media/base/producer_consumer_queue.h b/media/base/producer_consumer_queue.h index b8eda996b7..850935500f 100644 --- a/media/base/producer_consumer_queue.h +++ b/media/base/producer_consumer_queue.h @@ -17,16 +17,26 @@ namespace media { +static const size_t kUnlimitedCapacity = 0u; +static const int64 kInfiniteTimeout = -1; + /// A thread safe producer consumer queue implementation. It allows the standard /// push and pop operations. It also maintains a monotonically-increasing /// element position and allows peeking at the element at certain position. template class ProducerConsumerQueue { public: - /// Create a ProducerConsumerQueue. + /// Create a ProducerConsumerQueue starting from position 0. /// @param capacity is the maximum number of elements that the queue can hold /// at once. A value of zero means unlimited capacity. explicit ProducerConsumerQueue(size_t capacity); + + /// Create a ProducerConsumerQueue starting from indicated position. + /// @param capacity is the maximum number of elements that the queue can hold + /// at once. A value of zero means unlimited capacity. + /// @param starting_pos is the starting head position. + ProducerConsumerQueue(size_t capacity, size_t starting_pos); + ~ProducerConsumerQueue(); /// Push an element to the back of the queue. If the queue has reached its @@ -89,14 +99,14 @@ class ProducerConsumerQueue { /// returned value may be meaningless if the queue is empty. size_t HeadPos() const { base::AutoLock l(lock_); - return head_; + return head_pos_; } /// @return The position of the tail element in the queue. Note that the /// returned value may be meaningless if the queue is empty. size_t TailPos() const { base::AutoLock l(lock_); - return head_ + q_.size() - 1; + return head_pos_ + q_.size() - 1; } /// @return true if the queue has been stopped using Stop(). This allows @@ -107,12 +117,12 @@ class ProducerConsumerQueue { } private: - // Move head_ to center on pos. + // Move head_pos_ to center on pos. void SlideHeadOnCenter(size_t pos); const size_t capacity_; // Maximum number of elements; zero means unlimited. mutable base::Lock lock_; // Lock protecting all other variables below. - size_t head_; // Head position. + size_t head_pos_; // Head position. std::deque q_; // Internal queue holding the elements. base::ConditionVariable not_empty_cv_; base::ConditionVariable not_full_cv_; @@ -126,12 +136,23 @@ class ProducerConsumerQueue { template ProducerConsumerQueue::ProducerConsumerQueue(size_t capacity) : capacity_(capacity), - head_(0), + head_pos_(0), not_empty_cv_(&lock_), not_full_cv_(&lock_), new_element_cv_(&lock_), stop_requested_(false) {} +template +ProducerConsumerQueue::ProducerConsumerQueue(size_t capacity, + size_t starting_pos) + : capacity_(capacity), + head_pos_(starting_pos), + not_empty_cv_(&lock_), + not_full_cv_(&lock_), + new_element_cv_(&lock_), + stop_requested_(false) { +} + template ProducerConsumerQueue::~ProducerConsumerQueue() {} @@ -218,7 +239,7 @@ Status ProducerConsumerQueue::Pop(T* element, int64 timeout_ms) { *element = q_.front(); q_.pop_front(); - ++head_; + ++head_pos_; // Signal other consumers if we have more elements. if (woken && !q_.empty()) @@ -231,10 +252,11 @@ Status ProducerConsumerQueue::Peek(size_t pos, T* element, int64 timeout_ms) { base::AutoLock l(lock_); - if (pos < head_) { + if (pos < head_pos_) { return Status( error::INVALID_ARGUMENT, - base::StringPrintf("pos (%zu) is too small; head is %zu.", pos, head_)); + base::StringPrintf( + "pos (%zu) is too small; head is at %zu.", pos, head_pos_)); } bool woken = false; @@ -242,10 +264,10 @@ Status ProducerConsumerQueue::Peek(size_t pos, base::ElapsedTimer timer; base::TimeDelta timeout_delta = base::TimeDelta::FromMilliseconds(timeout_ms); - // Move head_ to create some space (move the sliding window centered @ pos). + // Move head to create some space (move the sliding window centered @ pos). SlideHeadOnCenter(pos); - while (pos >= head_ + q_.size()) { + while (pos >= head_pos_ + q_.size()) { if (stop_requested_) return Status(error::STOPPED, ""); @@ -262,12 +284,12 @@ Status ProducerConsumerQueue::Peek(size_t pos, return Status(error::TIME_OUT, "Time out on peeking."); } } - // Move head_ to create some space (move the sliding window centered @ pos). + // Move head to create some space (move the sliding window centered @ pos). SlideHeadOnCenter(pos); woken = true; } - *element = q_[pos - head_]; + *element = q_[pos - head_pos_]; // Signal other consumers if we have more elements. if (woken && !q_.empty()) @@ -281,11 +303,11 @@ void ProducerConsumerQueue::SlideHeadOnCenter(size_t pos) { if (capacity_) { // Signal producer to proceed if we are going to create some capacity. - if (q_.size() == capacity_ && pos > head_ + capacity_ / 2) + if (q_.size() == capacity_ && pos > head_pos_ + capacity_ / 2) not_full_cv_.Signal(); - while (!q_.empty() && pos > head_ + capacity_ / 2) { - ++head_; + while (!q_.empty() && pos > head_pos_ + capacity_ / 2) { + ++head_pos_; q_.pop_front(); } } diff --git a/media/base/producer_consumer_queue_unittest.cc b/media/base/producer_consumer_queue_unittest.cc index d27fcd9feb..c31ebdc97f 100644 --- a/media/base/producer_consumer_queue_unittest.cc +++ b/media/base/producer_consumer_queue_unittest.cc @@ -13,10 +13,8 @@ #include "testing/gtest/include/gtest/gtest.h" namespace { -const size_t kUnlimitedCapacity = 0u; const size_t kCapacity = 10u; const int64 kTimeout = 100; // 0.1s. -const int64 kInfiniteTimeout = -1; // Check that the |delta| is approximately |time_in_milliseconds|. bool CheckTimeApproxEqual(int64 time_in_milliseconds, diff --git a/media/base/widevine_encryption_key_source.cc b/media/base/widevine_encryption_key_source.cc index 2dcee2fc2c..cbcf980cc0 100644 --- a/media/base/widevine_encryption_key_source.cc +++ b/media/base/widevine_encryption_key_source.cc @@ -7,8 +7,10 @@ #include "media/base/widevine_encryption_key_source.h" #include "base/base64.h" +#include "base/bind.h" #include "base/json/json_reader.h" #include "base/json/json_writer.h" +#include "base/memory/ref_counted.h" #include "base/stl_util.h" #include "base/values.h" #include "media/base/http_fetcher.h" @@ -34,6 +36,11 @@ const char kLicenseStatusTransientError[] = "INTERNAL_ERROR"; const int kNumTransientErrorRetries = 5; const int kFirstRetryDelayMilliseconds = 1000; +// Default crypto period count, which is the number of keys to fetch on every +// key rotation enabled request. +const int kDefaultCryptoPeriodCount = 10; +const int kGetKeyTimeoutInSeconds = 5 * 60; // 5 minutes. + bool Base64StringToBytes(const std::string& base64_string, std::vector* bytes) { DCHECK(bytes); @@ -93,42 +100,68 @@ bool GetPsshData(const base::DictionaryValue& track_dict, namespace media { +// A ref counted wrapper for EncryptionKeyMap. +class WidevineEncryptionKeySource::RefCountedEncryptionKeyMap + : public base::RefCountedThreadSafe { + public: + explicit RefCountedEncryptionKeyMap(EncryptionKeyMap* encryption_key_map) { + DCHECK(encryption_key_map); + encryption_key_map_.swap(*encryption_key_map); + } + + std::map& map() { + return encryption_key_map_; + } + + private: + friend class base::RefCountedThreadSafe; + + ~RefCountedEncryptionKeyMap() { STLDeleteValues(&encryption_key_map_); } + + EncryptionKeyMap encryption_key_map_; + + DISALLOW_COPY_AND_ASSIGN(RefCountedEncryptionKeyMap); +}; + WidevineEncryptionKeySource::WidevineEncryptionKeySource( const std::string& server_url, const std::string& content_id, - scoped_ptr signer) + scoped_ptr signer, + int first_crypto_period_index) : http_fetcher_(new SimpleHttpFetcher()), server_url_(server_url), content_id_(content_id), signer_(signer.Pass()), - key_fetched_(false) { + key_rotation_enabled_(first_crypto_period_index >= 0), + crypto_period_count_(kDefaultCryptoPeriodCount), + first_crypto_period_index_(first_crypto_period_index), + key_production_thread_( + "KeyProductionThread", + base::Bind(&WidevineEncryptionKeySource::FetchKeysTask, + base::Unretained(this))), + key_pool_(kDefaultCryptoPeriodCount, + key_rotation_enabled_ ? first_crypto_period_index : 0) { DCHECK(signer_); + key_production_thread_.Start(); } + WidevineEncryptionKeySource::~WidevineEncryptionKeySource() { - STLDeleteValues(&encryption_key_map_); + key_pool_.Stop(); + key_production_thread_.Join(); } Status WidevineEncryptionKeySource::GetKey(TrackType track_type, EncryptionKey* key) { - DCHECK(track_type == TRACK_TYPE_SD || track_type == TRACK_TYPE_HD || - track_type == TRACK_TYPE_AUDIO); - Status status; - if (!key_fetched_) { - base::AutoLock auto_lock(lock_); - if (!key_fetched_) { - status = FetchKeys(); - if (status.ok()) - key_fetched_ = true; - } - } - if (!status.ok()) - return status; - if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) { - return Status(error::INTERNAL_ERROR, - "Cannot find key of type " + TrackTypeToString(track_type)); - } - *key = *encryption_key_map_[track_type]; - return Status::OK; + DCHECK(!key_rotation_enabled_); + return GetKeyInternal(0u, track_type, key); +} + +Status WidevineEncryptionKeySource::GetCryptoPeriodKey( + uint32 crypto_period_index, + TrackType track_type, + EncryptionKey* key) { + DCHECK(key_rotation_enabled_); + return GetKeyInternal(crypto_period_index, track_type, key); } void WidevineEncryptionKeySource::set_http_fetcher( @@ -136,9 +169,50 @@ void WidevineEncryptionKeySource::set_http_fetcher( http_fetcher_ = http_fetcher.Pass(); } -Status WidevineEncryptionKeySource::FetchKeys() { +Status WidevineEncryptionKeySource::GetKeyInternal( + uint32 crypto_period_index, + TrackType track_type, + EncryptionKey* key) { + DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES); + DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN); + + scoped_refptr ref_counted_encryption_key_map; + Status status = key_pool_.Peek(crypto_period_index, + &ref_counted_encryption_key_map, + kGetKeyTimeoutInSeconds * 1000); + if (!status.ok()) { + if (status.error_code() == error::STOPPED) { + CHECK(!common_encryption_request_status_.ok()); + return common_encryption_request_status_; + } + return status; + } + + EncryptionKeyMap& encryption_key_map = ref_counted_encryption_key_map->map(); + if (encryption_key_map.find(track_type) == encryption_key_map.end()) { + return Status(error::INTERNAL_ERROR, + "Cannot find key of type " + TrackTypeToString(track_type)); + } + *key = *encryption_key_map[track_type]; + return Status::OK; +} + +void WidevineEncryptionKeySource::FetchKeysTask() { + Status status = FetchKeys(first_crypto_period_index_); + if (key_rotation_enabled_) { + while (status.ok()) { + first_crypto_period_index_ += crypto_period_count_; + status = FetchKeys(first_crypto_period_index_); + } + } + common_encryption_request_status_ = status; + key_pool_.Stop(); +} + +Status WidevineEncryptionKeySource::FetchKeys( + uint32 first_crypto_period_index) { std::string request; - FillRequest(content_id_, &request); + FillRequest(content_id_, first_crypto_period_index, &request); std::string message; Status status = SignRequest(request, &message); @@ -185,6 +259,7 @@ Status WidevineEncryptionKeySource::FetchKeys() { } void WidevineEncryptionKeySource::FillRequest(const std::string& content_id, + uint32 first_crypto_period_index, std::string* request) { DCHECK(request); @@ -215,6 +290,13 @@ void WidevineEncryptionKeySource::FillRequest(const std::string& content_id, drm_types->AppendString("WIDEVINE"); request_dict.Set("drm_types", drm_types); + // Build key rotation fields. + if (key_rotation_enabled_) { + request_dict.SetInteger("first_crypto_period_index", + first_crypto_period_index); + request_dict.SetInteger("crypto_period_count", crypto_period_count_); + } + base::JSONWriter::Write(&request_dict, request); } @@ -288,17 +370,40 @@ bool WidevineEncryptionKeySource::ExtractEncryptionKey( const base::ListValue* tracks; RCHECK(license_dict->GetList("tracks", &tracks)); - RCHECK(tracks->GetSize() >= NUM_VALID_TRACK_TYPES); + RCHECK(key_rotation_enabled_ + ? tracks->GetSize() >= NUM_VALID_TRACK_TYPES * crypto_period_count_ + : tracks->GetSize() >= NUM_VALID_TRACK_TYPES); + int current_crypto_period_index = first_crypto_period_index_; + + EncryptionKeyMap encryption_key_map; for (size_t i = 0; i < tracks->GetSize(); ++i) { const base::DictionaryValue* track_dict; RCHECK(tracks->GetDictionary(i, &track_dict)); + if (key_rotation_enabled_) { + int crypto_period_index; + RCHECK( + track_dict->GetInteger("crypto_period_index", &crypto_period_index)); + if (crypto_period_index != current_crypto_period_index) { + if (crypto_period_index != current_crypto_period_index + 1) { + LOG(ERROR) << "Expecting crypto period index " + << current_crypto_period_index << " or " + << current_crypto_period_index + 1 << "; Seen " + << crypto_period_index << " at track " << i; + return false; + } + if (!PushToKeyPool(&encryption_key_map)) + return false; + ++current_crypto_period_index; + } + } + std::string track_type_str; RCHECK(track_dict->GetString("type", &track_type_str)); TrackType track_type = GetTrackTypeFromString(track_type_str); DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type); - RCHECK(encryption_key_map_.find(track_type) == encryption_key_map_.end()); + RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end()); scoped_ptr encryption_key(new EncryptionKey()); std::vector pssh_data; @@ -307,7 +412,24 @@ bool WidevineEncryptionKeySource::ExtractEncryptionKey( !GetPsshData(*track_dict, &pssh_data)) return false; encryption_key->pssh = PsshBoxFromPsshData(pssh_data); - encryption_key_map_[track_type] = encryption_key.release(); + encryption_key_map[track_type] = encryption_key.release(); + } + + DCHECK(!encryption_key_map.empty()); + return PushToKeyPool(&encryption_key_map); +} + +bool WidevineEncryptionKeySource::PushToKeyPool( + EncryptionKeyMap* encryption_key_map) { + DCHECK(encryption_key_map); + Status status = + key_pool_.Push(scoped_refptr( + new RefCountedEncryptionKeyMap(encryption_key_map)), + kInfiniteTimeout); + encryption_key_map->clear(); + if (!status.ok()) { + DCHECK_EQ(error::STOPPED, status.error_code()); + return false; } return true; } diff --git a/media/base/widevine_encryption_key_source.h b/media/base/widevine_encryption_key_source.h index 87d9249835..b1e027ac52 100644 --- a/media/base/widevine_encryption_key_source.h +++ b/media/base/widevine_encryption_key_source.h @@ -11,10 +11,13 @@ #include "base/basictypes.h" #include "base/memory/scoped_ptr.h" -#include "base/synchronization/lock.h" +#include "media/base/closure_thread.h" #include "media/base/encryption_key_source.h" +#include "media/base/producer_consumer_queue.h" namespace media { +/// A negative crypto period index disables key rotation. +static const int kDisableKeyRotation = -1; class HttpFetcher; class RequestSigner; @@ -25,26 +28,47 @@ class WidevineEncryptionKeySource : public EncryptionKeySource { public: /// @param server_url is the Widevine common encryption server url. /// @param content_id the unique id identify the content to be encrypted. - /// @param signer must not be NULL. + /// @param signer signs the request message. It should not be NULL. + /// @param first_crypto_period_index indicates the starting crypto period + /// index. Set it to kDisableKeyRotation to disable key rotation. WidevineEncryptionKeySource(const std::string& server_url, const std::string& content_id, - scoped_ptr signer); + scoped_ptr signer, + int first_crypto_period_index); virtual ~WidevineEncryptionKeySource(); - /// EncryptionKeySource implementation override. + /// @name EncryptionKeySource implementation overrides. + /// @{ virtual Status GetKey(TrackType track_type, EncryptionKey* key) OVERRIDE; + virtual Status GetCryptoPeriodKey(uint32 crypto_period_index, + TrackType track_type, + EncryptionKey* key) OVERRIDE; + /// @} /// Inject an @b HttpFetcher object, mainly used for testing. /// @param http_fetcher points to the @b HttpFetcher object to be injected. void set_http_fetcher(scoped_ptr http_fetcher); private: + typedef std::map EncryptionKeyMap; + class RefCountedEncryptionKeyMap; + + // Internal routine for getting keys. + Status GetKeyInternal(uint32 crypto_period_index, + TrackType track_type, + EncryptionKey* key); + + // The closure task to fetch keys repeatedly. + void FetchKeysTask(); + // Fetch keys from server. - Status FetchKeys(); + Status FetchKeys(uint32 first_crypto_period_index); // Fill |request| with necessary fields for Widevine encryption request. // |request| should not be NULL. - void FillRequest(const std::string& content_id, std::string* request); + void FillRequest(const std::string& content_id, + uint32 first_crypto_period_index, + std::string* request); // Sign and properly format |request|. // |signed_request| should not be NULL. Return OK on success. Status SignRequest(const std::string& request, std::string* signed_request); @@ -56,6 +80,8 @@ class WidevineEncryptionKeySource : public EncryptionKeySource { // failure is because of a transient error from the server. |transient_error| // should not be NULL. bool ExtractEncryptionKey(const std::string& response, bool* transient_error); + // Push the keys to the key pool. + bool PushToKeyPool(EncryptionKeyMap* encryption_key_map); // The fetcher object used to fetch HTTP response from server. // It is initialized to a default fetcher on class initialization. @@ -65,9 +91,12 @@ class WidevineEncryptionKeySource : public EncryptionKeySource { std::string content_id_; scoped_ptr signer_; - mutable base::Lock lock_; - bool key_fetched_; // Protected by lock_; - std::map encryption_key_map_; + const bool key_rotation_enabled_; + const uint32 crypto_period_count_; + uint32 first_crypto_period_index_; + ClosureThread key_production_thread_; + ProducerConsumerQueue > key_pool_; + Status common_encryption_request_status_; DISALLOW_COPY_AND_ASSIGN(WidevineEncryptionKeySource); }; diff --git a/media/base/widevine_encryption_key_source_unittest.cc b/media/base/widevine_encryption_key_source_unittest.cc index ac1c98bd82..ec23688a02 100644 --- a/media/base/widevine_encryption_key_source_unittest.cc +++ b/media/base/widevine_encryption_key_source_unittest.cc @@ -7,6 +7,7 @@ #include "media/base/widevine_encryption_key_source.h" #include "base/base64.h" +#include "base/strings/string_number_conversions.h" #include "base/strings/stringprintf.h" #include "media/base/http_fetcher.h" #include "media/base/request_signer.h" @@ -89,6 +90,7 @@ std::string GetPsshDataFromPsshBox(const std::string& pssh_box) { using ::testing::_; using ::testing::DoAll; +using ::testing::InSequence; using ::testing::Return; using ::testing::SetArgPointee; @@ -129,9 +131,12 @@ class WidevineEncryptionKeySourceTest : public ::testing::Test { mock_http_fetcher_(new MockHttpFetcher()) {} protected: - void CreateWidevineEncryptionKeySource() { + void CreateWidevineEncryptionKeySource(int first_crypto_period_index) { widevine_encryption_key_source_.reset(new WidevineEncryptionKeySource( - kServerUrl, kContentId, mock_request_signer_.PassAs())); + kServerUrl, + kContentId, + mock_request_signer_.PassAs(), + first_crypto_period_index)); widevine_encryption_key_source_->set_http_fetcher( mock_http_fetcher_.PassAs()); } @@ -155,11 +160,11 @@ TEST_F(WidevineEncryptionKeySourceTest, GetTrackTypeFromString) { EncryptionKeySource::GetTrackTypeFromString("FOO")); } -TEST_F(WidevineEncryptionKeySourceTest, GeneratureSignatureFailure) { +TEST_F(WidevineEncryptionKeySourceTest, GenerateSignatureFailure) { EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _)) .WillOnce(Return(false)); - CreateWidevineEncryptionKeySource(); + CreateWidevineEncryptionKeySource(kDisableKeyRotation); EncryptionKey encryption_key; ASSERT_EQ(Status(error::INTERNAL_ERROR, "Signature generation failed."), widevine_encryption_key_source_->GetKey( @@ -183,7 +188,7 @@ TEST_F(WidevineEncryptionKeySourceTest, HttpPostFailure) { EXPECT_CALL(*mock_http_fetcher_, Post(kServerUrl, expected_post_data, _)) .WillOnce(Return(kMockStatus)); - CreateWidevineEncryptionKeySource(); + CreateWidevineEncryptionKeySource(kDisableKeyRotation); EncryptionKey encryption_key; ASSERT_EQ(kMockStatus, widevine_encryption_key_source_->GetKey( @@ -194,13 +199,13 @@ TEST_F(WidevineEncryptionKeySourceTest, LicenseStatusOK) { EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _)) .WillOnce(Return(true)); - std::string expected_response = base::StringPrintf( + std::string mock_response = base::StringPrintf( kHttpResponseFormat, Base64Encode(GenerateMockLicenseResponse()).c_str()); EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) - .WillOnce(DoAll(SetArgPointee<2>(expected_response), Return(Status::OK))); + .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK))); - CreateWidevineEncryptionKeySource(); + CreateWidevineEncryptionKeySource(kDisableKeyRotation); EncryptionKey encryption_key; const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"}; @@ -221,7 +226,7 @@ TEST_F(WidevineEncryptionKeySourceTest, RetryOnTransientError) { std::string mock_license_status = base::StringPrintf( kLicenseResponseFormat, kLicenseStatusTransientError, ""); - std::string expected_response = base::StringPrintf( + std::string mock_response = base::StringPrintf( kHttpResponseFormat, Base64Encode(mock_license_status).c_str()); std::string expected_retried_response = base::StringPrintf( @@ -229,11 +234,11 @@ TEST_F(WidevineEncryptionKeySourceTest, RetryOnTransientError) { // Retry is expected on transient error. EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) - .WillOnce(DoAll(SetArgPointee<2>(expected_response), Return(Status::OK))) + .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK))) .WillOnce(DoAll(SetArgPointee<2>(expected_retried_response), Return(Status::OK))); - CreateWidevineEncryptionKeySource(); + CreateWidevineEncryptionKeySource(kDisableKeyRotation); EncryptionKey encryption_key; ASSERT_OK(widevine_encryption_key_source_->GetKey( EncryptionKeySource::TRACK_TYPE_SD, &encryption_key)); @@ -255,13 +260,131 @@ TEST_F(WidevineEncryptionKeySourceTest, NoRetryOnUnknownError) { EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK))); - CreateWidevineEncryptionKeySource(); + CreateWidevineEncryptionKeySource(kDisableKeyRotation); EncryptionKey encryption_key; - ASSERT_EQ( - error::SERVER_ERROR, - widevine_encryption_key_source_->GetKey( - EncryptionKeySource::TRACK_TYPE_SD, - &encryption_key).error_code()); + Status status = widevine_encryption_key_source_->GetKey( + EncryptionKeySource::TRACK_TYPE_SD, &encryption_key); + ASSERT_EQ(error::SERVER_ERROR, status.error_code()); +} + +namespace { + +const char kCryptoPeriodRequestMessageFormat[] = + "{\"content_id\":\"%s\",\"crypto_period_count\":%u,\"drm_types\":[" + "\"WIDEVINE\"],\"first_crypto_period_index\":%u,\"policy\":\"\"," + "\"tracks\":[{\"type\":\"SD\"},{\"type\":\"HD\"},{\"type\":\"AUDIO\"}]}"; + +const char kCryptoPeriodTrackFormat[] = + "{\"type\":\"%s\",\"key_id\":\"\",\"key\":" + "\"%s\",\"pssh\":[{\"drm_type\":\"WIDEVINE\",\"data\":\"\"}], " + "\"crypto_period_index\":%u}"; + +std::string GetMockKey(const std::string& track_type, uint32 index) { + return "MockKey" + track_type + "@" + base::UintToString(index); +} + +std::string GenerateMockLicenseResponse(uint32 initial_crypto_period_index, + uint32 crypto_period_count) { + const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"}; + std::string tracks; + for (uint32 index = initial_crypto_period_index; + index < initial_crypto_period_index + crypto_period_count; + ++index) { + for (size_t i = 0; i < 3; ++i) { + if (!tracks.empty()) + tracks += ","; + tracks += base::StringPrintf( + kCryptoPeriodTrackFormat, + kTrackTypes[i].c_str(), + Base64Encode(GetMockKey(kTrackTypes[i], index)).c_str(), + index); + } + } + return base::StringPrintf(kLicenseResponseFormat, "OK", tracks.c_str()); +} + +} // namespace + +TEST_F(WidevineEncryptionKeySourceTest, KeyRotationTest) { + const uint32 kFirstCryptoPeriodIndex = 8; + const uint32 kCryptoPeriodCount = 10; + // Array of indexes to be checked. + const uint32 kCryptoPeriodIndexes[] = {kFirstCryptoPeriodIndex, 17, 37, + 38, 36, 39}; + // Derived from kCryptoPeriodIndexes: ceiling((39 - 8 ) / 10). + const uint32 kCryptoIterations = 4; + + // Generate expectations in sequence. + InSequence dummy; + for (uint32 i = 0; i < kCryptoIterations; ++i) { + uint32 first_crypto_period_index = + kFirstCryptoPeriodIndex + i * kCryptoPeriodCount; + std::string expected_message = + base::StringPrintf(kCryptoPeriodRequestMessageFormat, + Base64Encode(kContentId).c_str(), + kCryptoPeriodCount, + first_crypto_period_index); + EXPECT_CALL(*mock_request_signer_, GenerateSignature(expected_message, _)) + .WillOnce(DoAll(SetArgPointee<1>(kMockSignature), Return(true))); + + std::string mock_response = base::StringPrintf( + kHttpResponseFormat, + Base64Encode(GenerateMockLicenseResponse(first_crypto_period_index, + kCryptoPeriodCount)).c_str()); + EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) + .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK))); + } + + CreateWidevineEncryptionKeySource(kFirstCryptoPeriodIndex); + + EncryptionKey encryption_key; + + // Index before kFirstCryptoPeriodIndex is invalid. + Status status = widevine_encryption_key_source_->GetCryptoPeriodKey( + kFirstCryptoPeriodIndex - 1, + EncryptionKeySource::TRACK_TYPE_SD, + &encryption_key); + EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code()); + + for (size_t i = 0; i < arraysize(kCryptoPeriodIndexes); ++i) { + const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"}; + for (size_t j = 0; j < 3; ++j) { + ASSERT_OK(widevine_encryption_key_source_->GetCryptoPeriodKey( + kCryptoPeriodIndexes[i], + EncryptionKeySource::GetTrackTypeFromString(kTrackTypes[j]), + &encryption_key)); + EXPECT_EQ(GetMockKey(kTrackTypes[j], kCryptoPeriodIndexes[i]), + ToString(encryption_key.key)); + } + } + + // The old crypto period indexes should have been garbage collected. + status = widevine_encryption_key_source_->GetCryptoPeriodKey( + kFirstCryptoPeriodIndex, + EncryptionKeySource::TRACK_TYPE_SD, + &encryption_key); + EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code()); +} + +class WidevineEncryptionKeySourceDeathTest + : public WidevineEncryptionKeySourceTest {}; + +TEST_F(WidevineEncryptionKeySourceDeathTest, + GetCryptoPeriodKeyOnNonKeyRotationSource) { + CreateWidevineEncryptionKeySource(kDisableKeyRotation); + EncryptionKey encryption_key; + EXPECT_DEBUG_DEATH( + widevine_encryption_key_source_->GetCryptoPeriodKey( + 0, EncryptionKeySource::TRACK_TYPE_SD, &encryption_key), + ""); +} + +TEST_F(WidevineEncryptionKeySourceDeathTest, GetKeyOnKeyRotationSource) { + CreateWidevineEncryptionKeySource(0); + EncryptionKey encryption_key; + EXPECT_DEBUG_DEATH(widevine_encryption_key_source_->GetKey( + EncryptionKeySource::TRACK_TYPE_SD, &encryption_key), + ""); } } // namespace media