Support key rotation in widevine encryption key source

Change-Id: I05ded15fa666119c86a1d3f1c99123b9cda60b49
This commit is contained in:
Kongqun Yang 2014-04-24 09:59:07 -07:00 committed by KongQun Yang
parent 1773d08b8d
commit 1f315ba921
9 changed files with 373 additions and 75 deletions

View File

@ -76,7 +76,10 @@ scoped_ptr<EncryptionKeySource> CreateEncryptionKeySource() {
} }
encryption_key_source.reset(new WidevineEncryptionKeySource( encryption_key_source.reset(new WidevineEncryptionKeySource(
FLAGS_server_url, FLAGS_content_id, signer.Pass())); FLAGS_server_url,
FLAGS_content_id,
signer.Pass(),
FLAGS_crypto_period_duration == 0 ? kDisableKeyRotation : 0));
} else if (FLAGS_enable_fixed_key_encryption) { } else if (FLAGS_enable_fixed_key_encryption) {
encryption_key_source = EncryptionKeySource::CreateFromHexStrings( encryption_key_source = EncryptionKeySource::CreateFromHexStrings(
FLAGS_key_id, FLAGS_key, FLAGS_pssh, ""); FLAGS_key_id, FLAGS_key, FLAGS_pssh, "");

View File

@ -169,11 +169,12 @@ void AesCbcEncryptor::Encrypt(const std::string& plaintext,
padded_text.append(num_padding_bytes, static_cast<char>(num_padding_bytes)); padded_text.append(num_padding_bytes, static_cast<char>(num_padding_bytes));
ciphertext->resize(padded_text.size()); ciphertext->resize(padded_text.size());
std::vector<uint8> iv(iv_);
AES_cbc_encrypt(reinterpret_cast<const uint8*>(padded_text.data()), AES_cbc_encrypt(reinterpret_cast<const uint8*>(padded_text.data()),
reinterpret_cast<uint8*>(string_as_array(ciphertext)), reinterpret_cast<uint8*>(string_as_array(ciphertext)),
padded_text.size(), padded_text.size(),
encrypt_key_.get(), encrypt_key_.get(),
&iv_[0], &iv[0],
AES_ENCRYPT); AES_ENCRYPT);
} }

View File

@ -30,7 +30,7 @@ Status EncryptionKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
return Status::OK; return Status::OK;
} }
Status EncryptionKeySource::GetCryptoPeriodKey(size_t crypto_period_index, Status EncryptionKeySource::GetCryptoPeriodKey(uint32 crypto_period_index,
TrackType track_type, TrackType track_type,
EncryptionKey* key) { EncryptionKey* key) {
NOTIMPLEMENTED(); NOTIMPLEMENTED();

View File

@ -43,7 +43,7 @@ class EncryptionKeySource {
/// Get encryption key of the specified track type at the specified index. /// Get encryption key of the specified track type at the specified index.
/// @return OK on success, an error status otherwise. /// @return OK on success, an error status otherwise.
virtual Status GetCryptoPeriodKey(size_t crypto_period_index, virtual Status GetCryptoPeriodKey(uint32 crypto_period_index,
TrackType track_type, TrackType track_type,
EncryptionKey* key); EncryptionKey* key);

View File

@ -17,16 +17,26 @@
namespace media { namespace media {
static const size_t kUnlimitedCapacity = 0u;
static const int64 kInfiniteTimeout = -1;
/// A thread safe producer consumer queue implementation. It allows the standard /// A thread safe producer consumer queue implementation. It allows the standard
/// push and pop operations. It also maintains a monotonically-increasing /// push and pop operations. It also maintains a monotonically-increasing
/// element position and allows peeking at the element at certain position. /// element position and allows peeking at the element at certain position.
template <class T> template <class T>
class ProducerConsumerQueue { class ProducerConsumerQueue {
public: public:
/// Create a ProducerConsumerQueue. /// Create a ProducerConsumerQueue starting from position 0.
/// @param capacity is the maximum number of elements that the queue can hold /// @param capacity is the maximum number of elements that the queue can hold
/// at once. A value of zero means unlimited capacity. /// at once. A value of zero means unlimited capacity.
explicit ProducerConsumerQueue(size_t capacity); explicit ProducerConsumerQueue(size_t capacity);
/// Create a ProducerConsumerQueue starting from indicated position.
/// @param capacity is the maximum number of elements that the queue can hold
/// at once. A value of zero means unlimited capacity.
/// @param starting_pos is the starting head position.
ProducerConsumerQueue(size_t capacity, size_t starting_pos);
~ProducerConsumerQueue(); ~ProducerConsumerQueue();
/// Push an element to the back of the queue. If the queue has reached its /// Push an element to the back of the queue. If the queue has reached its
@ -89,14 +99,14 @@ class ProducerConsumerQueue {
/// returned value may be meaningless if the queue is empty. /// returned value may be meaningless if the queue is empty.
size_t HeadPos() const { size_t HeadPos() const {
base::AutoLock l(lock_); base::AutoLock l(lock_);
return head_; return head_pos_;
} }
/// @return The position of the tail element in the queue. Note that the /// @return The position of the tail element in the queue. Note that the
/// returned value may be meaningless if the queue is empty. /// returned value may be meaningless if the queue is empty.
size_t TailPos() const { size_t TailPos() const {
base::AutoLock l(lock_); base::AutoLock l(lock_);
return head_ + q_.size() - 1; return head_pos_ + q_.size() - 1;
} }
/// @return true if the queue has been stopped using Stop(). This allows /// @return true if the queue has been stopped using Stop(). This allows
@ -107,12 +117,12 @@ class ProducerConsumerQueue {
} }
private: private:
// Move head_ to center on pos. // Move head_pos_ to center on pos.
void SlideHeadOnCenter(size_t pos); void SlideHeadOnCenter(size_t pos);
const size_t capacity_; // Maximum number of elements; zero means unlimited. const size_t capacity_; // Maximum number of elements; zero means unlimited.
mutable base::Lock lock_; // Lock protecting all other variables below. mutable base::Lock lock_; // Lock protecting all other variables below.
size_t head_; // Head position. size_t head_pos_; // Head position.
std::deque<T> q_; // Internal queue holding the elements. std::deque<T> q_; // Internal queue holding the elements.
base::ConditionVariable not_empty_cv_; base::ConditionVariable not_empty_cv_;
base::ConditionVariable not_full_cv_; base::ConditionVariable not_full_cv_;
@ -126,12 +136,23 @@ class ProducerConsumerQueue {
template <class T> template <class T>
ProducerConsumerQueue<T>::ProducerConsumerQueue(size_t capacity) ProducerConsumerQueue<T>::ProducerConsumerQueue(size_t capacity)
: capacity_(capacity), : capacity_(capacity),
head_(0), head_pos_(0),
not_empty_cv_(&lock_), not_empty_cv_(&lock_),
not_full_cv_(&lock_), not_full_cv_(&lock_),
new_element_cv_(&lock_), new_element_cv_(&lock_),
stop_requested_(false) {} stop_requested_(false) {}
template <class T>
ProducerConsumerQueue<T>::ProducerConsumerQueue(size_t capacity,
size_t starting_pos)
: capacity_(capacity),
head_pos_(starting_pos),
not_empty_cv_(&lock_),
not_full_cv_(&lock_),
new_element_cv_(&lock_),
stop_requested_(false) {
}
template <class T> template <class T>
ProducerConsumerQueue<T>::~ProducerConsumerQueue() {} ProducerConsumerQueue<T>::~ProducerConsumerQueue() {}
@ -218,7 +239,7 @@ Status ProducerConsumerQueue<T>::Pop(T* element, int64 timeout_ms) {
*element = q_.front(); *element = q_.front();
q_.pop_front(); q_.pop_front();
++head_; ++head_pos_;
// Signal other consumers if we have more elements. // Signal other consumers if we have more elements.
if (woken && !q_.empty()) if (woken && !q_.empty())
@ -231,10 +252,11 @@ Status ProducerConsumerQueue<T>::Peek(size_t pos,
T* element, T* element,
int64 timeout_ms) { int64 timeout_ms) {
base::AutoLock l(lock_); base::AutoLock l(lock_);
if (pos < head_) { if (pos < head_pos_) {
return Status( return Status(
error::INVALID_ARGUMENT, error::INVALID_ARGUMENT,
base::StringPrintf("pos (%zu) is too small; head is %zu.", pos, head_)); base::StringPrintf(
"pos (%zu) is too small; head is at %zu.", pos, head_pos_));
} }
bool woken = false; bool woken = false;
@ -242,10 +264,10 @@ Status ProducerConsumerQueue<T>::Peek(size_t pos,
base::ElapsedTimer timer; base::ElapsedTimer timer;
base::TimeDelta timeout_delta = base::TimeDelta::FromMilliseconds(timeout_ms); base::TimeDelta timeout_delta = base::TimeDelta::FromMilliseconds(timeout_ms);
// Move head_ to create some space (move the sliding window centered @ pos). // Move head to create some space (move the sliding window centered @ pos).
SlideHeadOnCenter(pos); SlideHeadOnCenter(pos);
while (pos >= head_ + q_.size()) { while (pos >= head_pos_ + q_.size()) {
if (stop_requested_) if (stop_requested_)
return Status(error::STOPPED, ""); return Status(error::STOPPED, "");
@ -262,12 +284,12 @@ Status ProducerConsumerQueue<T>::Peek(size_t pos,
return Status(error::TIME_OUT, "Time out on peeking."); return Status(error::TIME_OUT, "Time out on peeking.");
} }
} }
// Move head_ to create some space (move the sliding window centered @ pos). // Move head to create some space (move the sliding window centered @ pos).
SlideHeadOnCenter(pos); SlideHeadOnCenter(pos);
woken = true; woken = true;
} }
*element = q_[pos - head_]; *element = q_[pos - head_pos_];
// Signal other consumers if we have more elements. // Signal other consumers if we have more elements.
if (woken && !q_.empty()) if (woken && !q_.empty())
@ -281,11 +303,11 @@ void ProducerConsumerQueue<T>::SlideHeadOnCenter(size_t pos) {
if (capacity_) { if (capacity_) {
// Signal producer to proceed if we are going to create some capacity. // Signal producer to proceed if we are going to create some capacity.
if (q_.size() == capacity_ && pos > head_ + capacity_ / 2) if (q_.size() == capacity_ && pos > head_pos_ + capacity_ / 2)
not_full_cv_.Signal(); not_full_cv_.Signal();
while (!q_.empty() && pos > head_ + capacity_ / 2) { while (!q_.empty() && pos > head_pos_ + capacity_ / 2) {
++head_; ++head_pos_;
q_.pop_front(); q_.pop_front();
} }
} }

View File

@ -13,10 +13,8 @@
#include "testing/gtest/include/gtest/gtest.h" #include "testing/gtest/include/gtest/gtest.h"
namespace { namespace {
const size_t kUnlimitedCapacity = 0u;
const size_t kCapacity = 10u; const size_t kCapacity = 10u;
const int64 kTimeout = 100; // 0.1s. const int64 kTimeout = 100; // 0.1s.
const int64 kInfiniteTimeout = -1;
// Check that the |delta| is approximately |time_in_milliseconds|. // Check that the |delta| is approximately |time_in_milliseconds|.
bool CheckTimeApproxEqual(int64 time_in_milliseconds, bool CheckTimeApproxEqual(int64 time_in_milliseconds,

View File

@ -7,8 +7,10 @@
#include "media/base/widevine_encryption_key_source.h" #include "media/base/widevine_encryption_key_source.h"
#include "base/base64.h" #include "base/base64.h"
#include "base/bind.h"
#include "base/json/json_reader.h" #include "base/json/json_reader.h"
#include "base/json/json_writer.h" #include "base/json/json_writer.h"
#include "base/memory/ref_counted.h"
#include "base/stl_util.h" #include "base/stl_util.h"
#include "base/values.h" #include "base/values.h"
#include "media/base/http_fetcher.h" #include "media/base/http_fetcher.h"
@ -34,6 +36,11 @@ const char kLicenseStatusTransientError[] = "INTERNAL_ERROR";
const int kNumTransientErrorRetries = 5; const int kNumTransientErrorRetries = 5;
const int kFirstRetryDelayMilliseconds = 1000; const int kFirstRetryDelayMilliseconds = 1000;
// Default crypto period count, which is the number of keys to fetch on every
// key rotation enabled request.
const int kDefaultCryptoPeriodCount = 10;
const int kGetKeyTimeoutInSeconds = 5 * 60; // 5 minutes.
bool Base64StringToBytes(const std::string& base64_string, bool Base64StringToBytes(const std::string& base64_string,
std::vector<uint8>* bytes) { std::vector<uint8>* bytes) {
DCHECK(bytes); DCHECK(bytes);
@ -93,42 +100,68 @@ bool GetPsshData(const base::DictionaryValue& track_dict,
namespace media { namespace media {
// A ref counted wrapper for EncryptionKeyMap.
class WidevineEncryptionKeySource::RefCountedEncryptionKeyMap
: public base::RefCountedThreadSafe<RefCountedEncryptionKeyMap> {
public:
explicit RefCountedEncryptionKeyMap(EncryptionKeyMap* encryption_key_map) {
DCHECK(encryption_key_map);
encryption_key_map_.swap(*encryption_key_map);
}
std::map<EncryptionKeySource::TrackType, EncryptionKey*>& map() {
return encryption_key_map_;
}
private:
friend class base::RefCountedThreadSafe<RefCountedEncryptionKeyMap>;
~RefCountedEncryptionKeyMap() { STLDeleteValues(&encryption_key_map_); }
EncryptionKeyMap encryption_key_map_;
DISALLOW_COPY_AND_ASSIGN(RefCountedEncryptionKeyMap);
};
WidevineEncryptionKeySource::WidevineEncryptionKeySource( WidevineEncryptionKeySource::WidevineEncryptionKeySource(
const std::string& server_url, const std::string& server_url,
const std::string& content_id, const std::string& content_id,
scoped_ptr<RequestSigner> signer) scoped_ptr<RequestSigner> signer,
int first_crypto_period_index)
: http_fetcher_(new SimpleHttpFetcher()), : http_fetcher_(new SimpleHttpFetcher()),
server_url_(server_url), server_url_(server_url),
content_id_(content_id), content_id_(content_id),
signer_(signer.Pass()), signer_(signer.Pass()),
key_fetched_(false) { key_rotation_enabled_(first_crypto_period_index >= 0),
crypto_period_count_(kDefaultCryptoPeriodCount),
first_crypto_period_index_(first_crypto_period_index),
key_production_thread_(
"KeyProductionThread",
base::Bind(&WidevineEncryptionKeySource::FetchKeysTask,
base::Unretained(this))),
key_pool_(kDefaultCryptoPeriodCount,
key_rotation_enabled_ ? first_crypto_period_index : 0) {
DCHECK(signer_); DCHECK(signer_);
key_production_thread_.Start();
} }
WidevineEncryptionKeySource::~WidevineEncryptionKeySource() { WidevineEncryptionKeySource::~WidevineEncryptionKeySource() {
STLDeleteValues(&encryption_key_map_); key_pool_.Stop();
key_production_thread_.Join();
} }
Status WidevineEncryptionKeySource::GetKey(TrackType track_type, Status WidevineEncryptionKeySource::GetKey(TrackType track_type,
EncryptionKey* key) { EncryptionKey* key) {
DCHECK(track_type == TRACK_TYPE_SD || track_type == TRACK_TYPE_HD || DCHECK(!key_rotation_enabled_);
track_type == TRACK_TYPE_AUDIO); return GetKeyInternal(0u, track_type, key);
Status status; }
if (!key_fetched_) {
base::AutoLock auto_lock(lock_); Status WidevineEncryptionKeySource::GetCryptoPeriodKey(
if (!key_fetched_) { uint32 crypto_period_index,
status = FetchKeys(); TrackType track_type,
if (status.ok()) EncryptionKey* key) {
key_fetched_ = true; DCHECK(key_rotation_enabled_);
} return GetKeyInternal(crypto_period_index, track_type, key);
}
if (!status.ok())
return status;
if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
return Status(error::INTERNAL_ERROR,
"Cannot find key of type " + TrackTypeToString(track_type));
}
*key = *encryption_key_map_[track_type];
return Status::OK;
} }
void WidevineEncryptionKeySource::set_http_fetcher( void WidevineEncryptionKeySource::set_http_fetcher(
@ -136,9 +169,50 @@ void WidevineEncryptionKeySource::set_http_fetcher(
http_fetcher_ = http_fetcher.Pass(); http_fetcher_ = http_fetcher.Pass();
} }
Status WidevineEncryptionKeySource::FetchKeys() { Status WidevineEncryptionKeySource::GetKeyInternal(
uint32 crypto_period_index,
TrackType track_type,
EncryptionKey* key) {
DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
scoped_refptr<RefCountedEncryptionKeyMap> ref_counted_encryption_key_map;
Status status = key_pool_.Peek(crypto_period_index,
&ref_counted_encryption_key_map,
kGetKeyTimeoutInSeconds * 1000);
if (!status.ok()) {
if (status.error_code() == error::STOPPED) {
CHECK(!common_encryption_request_status_.ok());
return common_encryption_request_status_;
}
return status;
}
EncryptionKeyMap& encryption_key_map = ref_counted_encryption_key_map->map();
if (encryption_key_map.find(track_type) == encryption_key_map.end()) {
return Status(error::INTERNAL_ERROR,
"Cannot find key of type " + TrackTypeToString(track_type));
}
*key = *encryption_key_map[track_type];
return Status::OK;
}
void WidevineEncryptionKeySource::FetchKeysTask() {
Status status = FetchKeys(first_crypto_period_index_);
if (key_rotation_enabled_) {
while (status.ok()) {
first_crypto_period_index_ += crypto_period_count_;
status = FetchKeys(first_crypto_period_index_);
}
}
common_encryption_request_status_ = status;
key_pool_.Stop();
}
Status WidevineEncryptionKeySource::FetchKeys(
uint32 first_crypto_period_index) {
std::string request; std::string request;
FillRequest(content_id_, &request); FillRequest(content_id_, first_crypto_period_index, &request);
std::string message; std::string message;
Status status = SignRequest(request, &message); Status status = SignRequest(request, &message);
@ -185,6 +259,7 @@ Status WidevineEncryptionKeySource::FetchKeys() {
} }
void WidevineEncryptionKeySource::FillRequest(const std::string& content_id, void WidevineEncryptionKeySource::FillRequest(const std::string& content_id,
uint32 first_crypto_period_index,
std::string* request) { std::string* request) {
DCHECK(request); DCHECK(request);
@ -215,6 +290,13 @@ void WidevineEncryptionKeySource::FillRequest(const std::string& content_id,
drm_types->AppendString("WIDEVINE"); drm_types->AppendString("WIDEVINE");
request_dict.Set("drm_types", drm_types); request_dict.Set("drm_types", drm_types);
// Build key rotation fields.
if (key_rotation_enabled_) {
request_dict.SetInteger("first_crypto_period_index",
first_crypto_period_index);
request_dict.SetInteger("crypto_period_count", crypto_period_count_);
}
base::JSONWriter::Write(&request_dict, request); base::JSONWriter::Write(&request_dict, request);
} }
@ -288,17 +370,40 @@ bool WidevineEncryptionKeySource::ExtractEncryptionKey(
const base::ListValue* tracks; const base::ListValue* tracks;
RCHECK(license_dict->GetList("tracks", &tracks)); RCHECK(license_dict->GetList("tracks", &tracks));
RCHECK(tracks->GetSize() >= NUM_VALID_TRACK_TYPES); RCHECK(key_rotation_enabled_
? tracks->GetSize() >= NUM_VALID_TRACK_TYPES * crypto_period_count_
: tracks->GetSize() >= NUM_VALID_TRACK_TYPES);
int current_crypto_period_index = first_crypto_period_index_;
EncryptionKeyMap encryption_key_map;
for (size_t i = 0; i < tracks->GetSize(); ++i) { for (size_t i = 0; i < tracks->GetSize(); ++i) {
const base::DictionaryValue* track_dict; const base::DictionaryValue* track_dict;
RCHECK(tracks->GetDictionary(i, &track_dict)); RCHECK(tracks->GetDictionary(i, &track_dict));
if (key_rotation_enabled_) {
int crypto_period_index;
RCHECK(
track_dict->GetInteger("crypto_period_index", &crypto_period_index));
if (crypto_period_index != current_crypto_period_index) {
if (crypto_period_index != current_crypto_period_index + 1) {
LOG(ERROR) << "Expecting crypto period index "
<< current_crypto_period_index << " or "
<< current_crypto_period_index + 1 << "; Seen "
<< crypto_period_index << " at track " << i;
return false;
}
if (!PushToKeyPool(&encryption_key_map))
return false;
++current_crypto_period_index;
}
}
std::string track_type_str; std::string track_type_str;
RCHECK(track_dict->GetString("type", &track_type_str)); RCHECK(track_dict->GetString("type", &track_type_str));
TrackType track_type = GetTrackTypeFromString(track_type_str); TrackType track_type = GetTrackTypeFromString(track_type_str);
DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type); DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
RCHECK(encryption_key_map_.find(track_type) == encryption_key_map_.end()); RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
scoped_ptr<EncryptionKey> encryption_key(new EncryptionKey()); scoped_ptr<EncryptionKey> encryption_key(new EncryptionKey());
std::vector<uint8> pssh_data; std::vector<uint8> pssh_data;
@ -307,7 +412,24 @@ bool WidevineEncryptionKeySource::ExtractEncryptionKey(
!GetPsshData(*track_dict, &pssh_data)) !GetPsshData(*track_dict, &pssh_data))
return false; return false;
encryption_key->pssh = PsshBoxFromPsshData(pssh_data); encryption_key->pssh = PsshBoxFromPsshData(pssh_data);
encryption_key_map_[track_type] = encryption_key.release(); encryption_key_map[track_type] = encryption_key.release();
}
DCHECK(!encryption_key_map.empty());
return PushToKeyPool(&encryption_key_map);
}
bool WidevineEncryptionKeySource::PushToKeyPool(
EncryptionKeyMap* encryption_key_map) {
DCHECK(encryption_key_map);
Status status =
key_pool_.Push(scoped_refptr<RefCountedEncryptionKeyMap>(
new RefCountedEncryptionKeyMap(encryption_key_map)),
kInfiniteTimeout);
encryption_key_map->clear();
if (!status.ok()) {
DCHECK_EQ(error::STOPPED, status.error_code());
return false;
} }
return true; return true;
} }

View File

@ -11,10 +11,13 @@
#include "base/basictypes.h" #include "base/basictypes.h"
#include "base/memory/scoped_ptr.h" #include "base/memory/scoped_ptr.h"
#include "base/synchronization/lock.h" #include "media/base/closure_thread.h"
#include "media/base/encryption_key_source.h" #include "media/base/encryption_key_source.h"
#include "media/base/producer_consumer_queue.h"
namespace media { namespace media {
/// A negative crypto period index disables key rotation.
static const int kDisableKeyRotation = -1;
class HttpFetcher; class HttpFetcher;
class RequestSigner; class RequestSigner;
@ -25,26 +28,47 @@ class WidevineEncryptionKeySource : public EncryptionKeySource {
public: public:
/// @param server_url is the Widevine common encryption server url. /// @param server_url is the Widevine common encryption server url.
/// @param content_id the unique id identify the content to be encrypted. /// @param content_id the unique id identify the content to be encrypted.
/// @param signer must not be NULL. /// @param signer signs the request message. It should not be NULL.
/// @param first_crypto_period_index indicates the starting crypto period
/// index. Set it to kDisableKeyRotation to disable key rotation.
WidevineEncryptionKeySource(const std::string& server_url, WidevineEncryptionKeySource(const std::string& server_url,
const std::string& content_id, const std::string& content_id,
scoped_ptr<RequestSigner> signer); scoped_ptr<RequestSigner> signer,
int first_crypto_period_index);
virtual ~WidevineEncryptionKeySource(); virtual ~WidevineEncryptionKeySource();
/// EncryptionKeySource implementation override. /// @name EncryptionKeySource implementation overrides.
/// @{
virtual Status GetKey(TrackType track_type, EncryptionKey* key) OVERRIDE; virtual Status GetKey(TrackType track_type, EncryptionKey* key) OVERRIDE;
virtual Status GetCryptoPeriodKey(uint32 crypto_period_index,
TrackType track_type,
EncryptionKey* key) OVERRIDE;
/// @}
/// Inject an @b HttpFetcher object, mainly used for testing. /// Inject an @b HttpFetcher object, mainly used for testing.
/// @param http_fetcher points to the @b HttpFetcher object to be injected. /// @param http_fetcher points to the @b HttpFetcher object to be injected.
void set_http_fetcher(scoped_ptr<HttpFetcher> http_fetcher); void set_http_fetcher(scoped_ptr<HttpFetcher> http_fetcher);
private: private:
typedef std::map<TrackType, EncryptionKey*> EncryptionKeyMap;
class RefCountedEncryptionKeyMap;
// Internal routine for getting keys.
Status GetKeyInternal(uint32 crypto_period_index,
TrackType track_type,
EncryptionKey* key);
// The closure task to fetch keys repeatedly.
void FetchKeysTask();
// Fetch keys from server. // Fetch keys from server.
Status FetchKeys(); Status FetchKeys(uint32 first_crypto_period_index);
// Fill |request| with necessary fields for Widevine encryption request. // Fill |request| with necessary fields for Widevine encryption request.
// |request| should not be NULL. // |request| should not be NULL.
void FillRequest(const std::string& content_id, std::string* request); void FillRequest(const std::string& content_id,
uint32 first_crypto_period_index,
std::string* request);
// Sign and properly format |request|. // Sign and properly format |request|.
// |signed_request| should not be NULL. Return OK on success. // |signed_request| should not be NULL. Return OK on success.
Status SignRequest(const std::string& request, std::string* signed_request); Status SignRequest(const std::string& request, std::string* signed_request);
@ -56,6 +80,8 @@ class WidevineEncryptionKeySource : public EncryptionKeySource {
// failure is because of a transient error from the server. |transient_error| // failure is because of a transient error from the server. |transient_error|
// should not be NULL. // should not be NULL.
bool ExtractEncryptionKey(const std::string& response, bool* transient_error); bool ExtractEncryptionKey(const std::string& response, bool* transient_error);
// Push the keys to the key pool.
bool PushToKeyPool(EncryptionKeyMap* encryption_key_map);
// The fetcher object used to fetch HTTP response from server. // The fetcher object used to fetch HTTP response from server.
// It is initialized to a default fetcher on class initialization. // It is initialized to a default fetcher on class initialization.
@ -65,9 +91,12 @@ class WidevineEncryptionKeySource : public EncryptionKeySource {
std::string content_id_; std::string content_id_;
scoped_ptr<RequestSigner> signer_; scoped_ptr<RequestSigner> signer_;
mutable base::Lock lock_; const bool key_rotation_enabled_;
bool key_fetched_; // Protected by lock_; const uint32 crypto_period_count_;
std::map<TrackType, EncryptionKey*> encryption_key_map_; uint32 first_crypto_period_index_;
ClosureThread key_production_thread_;
ProducerConsumerQueue<scoped_refptr<RefCountedEncryptionKeyMap> > key_pool_;
Status common_encryption_request_status_;
DISALLOW_COPY_AND_ASSIGN(WidevineEncryptionKeySource); DISALLOW_COPY_AND_ASSIGN(WidevineEncryptionKeySource);
}; };

View File

@ -7,6 +7,7 @@
#include "media/base/widevine_encryption_key_source.h" #include "media/base/widevine_encryption_key_source.h"
#include "base/base64.h" #include "base/base64.h"
#include "base/strings/string_number_conversions.h"
#include "base/strings/stringprintf.h" #include "base/strings/stringprintf.h"
#include "media/base/http_fetcher.h" #include "media/base/http_fetcher.h"
#include "media/base/request_signer.h" #include "media/base/request_signer.h"
@ -89,6 +90,7 @@ std::string GetPsshDataFromPsshBox(const std::string& pssh_box) {
using ::testing::_; using ::testing::_;
using ::testing::DoAll; using ::testing::DoAll;
using ::testing::InSequence;
using ::testing::Return; using ::testing::Return;
using ::testing::SetArgPointee; using ::testing::SetArgPointee;
@ -129,9 +131,12 @@ class WidevineEncryptionKeySourceTest : public ::testing::Test {
mock_http_fetcher_(new MockHttpFetcher()) {} mock_http_fetcher_(new MockHttpFetcher()) {}
protected: protected:
void CreateWidevineEncryptionKeySource() { void CreateWidevineEncryptionKeySource(int first_crypto_period_index) {
widevine_encryption_key_source_.reset(new WidevineEncryptionKeySource( widevine_encryption_key_source_.reset(new WidevineEncryptionKeySource(
kServerUrl, kContentId, mock_request_signer_.PassAs<RequestSigner>())); kServerUrl,
kContentId,
mock_request_signer_.PassAs<RequestSigner>(),
first_crypto_period_index));
widevine_encryption_key_source_->set_http_fetcher( widevine_encryption_key_source_->set_http_fetcher(
mock_http_fetcher_.PassAs<HttpFetcher>()); mock_http_fetcher_.PassAs<HttpFetcher>());
} }
@ -155,11 +160,11 @@ TEST_F(WidevineEncryptionKeySourceTest, GetTrackTypeFromString) {
EncryptionKeySource::GetTrackTypeFromString("FOO")); EncryptionKeySource::GetTrackTypeFromString("FOO"));
} }
TEST_F(WidevineEncryptionKeySourceTest, GeneratureSignatureFailure) { TEST_F(WidevineEncryptionKeySourceTest, GenerateSignatureFailure) {
EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _)) EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _))
.WillOnce(Return(false)); .WillOnce(Return(false));
CreateWidevineEncryptionKeySource(); CreateWidevineEncryptionKeySource(kDisableKeyRotation);
EncryptionKey encryption_key; EncryptionKey encryption_key;
ASSERT_EQ(Status(error::INTERNAL_ERROR, "Signature generation failed."), ASSERT_EQ(Status(error::INTERNAL_ERROR, "Signature generation failed."),
widevine_encryption_key_source_->GetKey( widevine_encryption_key_source_->GetKey(
@ -183,7 +188,7 @@ TEST_F(WidevineEncryptionKeySourceTest, HttpPostFailure) {
EXPECT_CALL(*mock_http_fetcher_, Post(kServerUrl, expected_post_data, _)) EXPECT_CALL(*mock_http_fetcher_, Post(kServerUrl, expected_post_data, _))
.WillOnce(Return(kMockStatus)); .WillOnce(Return(kMockStatus));
CreateWidevineEncryptionKeySource(); CreateWidevineEncryptionKeySource(kDisableKeyRotation);
EncryptionKey encryption_key; EncryptionKey encryption_key;
ASSERT_EQ(kMockStatus, ASSERT_EQ(kMockStatus,
widevine_encryption_key_source_->GetKey( widevine_encryption_key_source_->GetKey(
@ -194,13 +199,13 @@ TEST_F(WidevineEncryptionKeySourceTest, LicenseStatusOK) {
EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _)) EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _))
.WillOnce(Return(true)); .WillOnce(Return(true));
std::string expected_response = base::StringPrintf( std::string mock_response = base::StringPrintf(
kHttpResponseFormat, Base64Encode(GenerateMockLicenseResponse()).c_str()); kHttpResponseFormat, Base64Encode(GenerateMockLicenseResponse()).c_str());
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
.WillOnce(DoAll(SetArgPointee<2>(expected_response), Return(Status::OK))); .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)));
CreateWidevineEncryptionKeySource(); CreateWidevineEncryptionKeySource(kDisableKeyRotation);
EncryptionKey encryption_key; EncryptionKey encryption_key;
const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"}; const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"};
@ -221,7 +226,7 @@ TEST_F(WidevineEncryptionKeySourceTest, RetryOnTransientError) {
std::string mock_license_status = base::StringPrintf( std::string mock_license_status = base::StringPrintf(
kLicenseResponseFormat, kLicenseStatusTransientError, ""); kLicenseResponseFormat, kLicenseStatusTransientError, "");
std::string expected_response = base::StringPrintf( std::string mock_response = base::StringPrintf(
kHttpResponseFormat, Base64Encode(mock_license_status).c_str()); kHttpResponseFormat, Base64Encode(mock_license_status).c_str());
std::string expected_retried_response = base::StringPrintf( std::string expected_retried_response = base::StringPrintf(
@ -229,11 +234,11 @@ TEST_F(WidevineEncryptionKeySourceTest, RetryOnTransientError) {
// Retry is expected on transient error. // Retry is expected on transient error.
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
.WillOnce(DoAll(SetArgPointee<2>(expected_response), Return(Status::OK))) .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)))
.WillOnce(DoAll(SetArgPointee<2>(expected_retried_response), .WillOnce(DoAll(SetArgPointee<2>(expected_retried_response),
Return(Status::OK))); Return(Status::OK)));
CreateWidevineEncryptionKeySource(); CreateWidevineEncryptionKeySource(kDisableKeyRotation);
EncryptionKey encryption_key; EncryptionKey encryption_key;
ASSERT_OK(widevine_encryption_key_source_->GetKey( ASSERT_OK(widevine_encryption_key_source_->GetKey(
EncryptionKeySource::TRACK_TYPE_SD, &encryption_key)); EncryptionKeySource::TRACK_TYPE_SD, &encryption_key));
@ -255,13 +260,131 @@ TEST_F(WidevineEncryptionKeySourceTest, NoRetryOnUnknownError) {
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _)) EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
.WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK))); .WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)));
CreateWidevineEncryptionKeySource(); CreateWidevineEncryptionKeySource(kDisableKeyRotation);
EncryptionKey encryption_key; EncryptionKey encryption_key;
ASSERT_EQ( Status status = widevine_encryption_key_source_->GetKey(
error::SERVER_ERROR, EncryptionKeySource::TRACK_TYPE_SD, &encryption_key);
widevine_encryption_key_source_->GetKey( ASSERT_EQ(error::SERVER_ERROR, status.error_code());
EncryptionKeySource::TRACK_TYPE_SD, }
&encryption_key).error_code());
namespace {
const char kCryptoPeriodRequestMessageFormat[] =
"{\"content_id\":\"%s\",\"crypto_period_count\":%u,\"drm_types\":["
"\"WIDEVINE\"],\"first_crypto_period_index\":%u,\"policy\":\"\","
"\"tracks\":[{\"type\":\"SD\"},{\"type\":\"HD\"},{\"type\":\"AUDIO\"}]}";
const char kCryptoPeriodTrackFormat[] =
"{\"type\":\"%s\",\"key_id\":\"\",\"key\":"
"\"%s\",\"pssh\":[{\"drm_type\":\"WIDEVINE\",\"data\":\"\"}], "
"\"crypto_period_index\":%u}";
std::string GetMockKey(const std::string& track_type, uint32 index) {
return "MockKey" + track_type + "@" + base::UintToString(index);
}
std::string GenerateMockLicenseResponse(uint32 initial_crypto_period_index,
uint32 crypto_period_count) {
const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"};
std::string tracks;
for (uint32 index = initial_crypto_period_index;
index < initial_crypto_period_index + crypto_period_count;
++index) {
for (size_t i = 0; i < 3; ++i) {
if (!tracks.empty())
tracks += ",";
tracks += base::StringPrintf(
kCryptoPeriodTrackFormat,
kTrackTypes[i].c_str(),
Base64Encode(GetMockKey(kTrackTypes[i], index)).c_str(),
index);
}
}
return base::StringPrintf(kLicenseResponseFormat, "OK", tracks.c_str());
}
} // namespace
TEST_F(WidevineEncryptionKeySourceTest, KeyRotationTest) {
const uint32 kFirstCryptoPeriodIndex = 8;
const uint32 kCryptoPeriodCount = 10;
// Array of indexes to be checked.
const uint32 kCryptoPeriodIndexes[] = {kFirstCryptoPeriodIndex, 17, 37,
38, 36, 39};
// Derived from kCryptoPeriodIndexes: ceiling((39 - 8 ) / 10).
const uint32 kCryptoIterations = 4;
// Generate expectations in sequence.
InSequence dummy;
for (uint32 i = 0; i < kCryptoIterations; ++i) {
uint32 first_crypto_period_index =
kFirstCryptoPeriodIndex + i * kCryptoPeriodCount;
std::string expected_message =
base::StringPrintf(kCryptoPeriodRequestMessageFormat,
Base64Encode(kContentId).c_str(),
kCryptoPeriodCount,
first_crypto_period_index);
EXPECT_CALL(*mock_request_signer_, GenerateSignature(expected_message, _))
.WillOnce(DoAll(SetArgPointee<1>(kMockSignature), Return(true)));
std::string mock_response = base::StringPrintf(
kHttpResponseFormat,
Base64Encode(GenerateMockLicenseResponse(first_crypto_period_index,
kCryptoPeriodCount)).c_str());
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
.WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)));
}
CreateWidevineEncryptionKeySource(kFirstCryptoPeriodIndex);
EncryptionKey encryption_key;
// Index before kFirstCryptoPeriodIndex is invalid.
Status status = widevine_encryption_key_source_->GetCryptoPeriodKey(
kFirstCryptoPeriodIndex - 1,
EncryptionKeySource::TRACK_TYPE_SD,
&encryption_key);
EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code());
for (size_t i = 0; i < arraysize(kCryptoPeriodIndexes); ++i) {
const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"};
for (size_t j = 0; j < 3; ++j) {
ASSERT_OK(widevine_encryption_key_source_->GetCryptoPeriodKey(
kCryptoPeriodIndexes[i],
EncryptionKeySource::GetTrackTypeFromString(kTrackTypes[j]),
&encryption_key));
EXPECT_EQ(GetMockKey(kTrackTypes[j], kCryptoPeriodIndexes[i]),
ToString(encryption_key.key));
}
}
// The old crypto period indexes should have been garbage collected.
status = widevine_encryption_key_source_->GetCryptoPeriodKey(
kFirstCryptoPeriodIndex,
EncryptionKeySource::TRACK_TYPE_SD,
&encryption_key);
EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code());
}
class WidevineEncryptionKeySourceDeathTest
: public WidevineEncryptionKeySourceTest {};
TEST_F(WidevineEncryptionKeySourceDeathTest,
GetCryptoPeriodKeyOnNonKeyRotationSource) {
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
EncryptionKey encryption_key;
EXPECT_DEBUG_DEATH(
widevine_encryption_key_source_->GetCryptoPeriodKey(
0, EncryptionKeySource::TRACK_TYPE_SD, &encryption_key),
"");
}
TEST_F(WidevineEncryptionKeySourceDeathTest, GetKeyOnKeyRotationSource) {
CreateWidevineEncryptionKeySource(0);
EncryptionKey encryption_key;
EXPECT_DEBUG_DEATH(widevine_encryption_key_source_->GetKey(
EncryptionKeySource::TRACK_TYPE_SD, &encryption_key),
"");
} }
} // namespace media } // namespace media