Support key rotation in widevine encryption key source
Change-Id: I05ded15fa666119c86a1d3f1c99123b9cda60b49
This commit is contained in:
parent
1773d08b8d
commit
1f315ba921
|
@ -76,7 +76,10 @@ scoped_ptr<EncryptionKeySource> CreateEncryptionKeySource() {
|
|||
}
|
||||
|
||||
encryption_key_source.reset(new WidevineEncryptionKeySource(
|
||||
FLAGS_server_url, FLAGS_content_id, signer.Pass()));
|
||||
FLAGS_server_url,
|
||||
FLAGS_content_id,
|
||||
signer.Pass(),
|
||||
FLAGS_crypto_period_duration == 0 ? kDisableKeyRotation : 0));
|
||||
} else if (FLAGS_enable_fixed_key_encryption) {
|
||||
encryption_key_source = EncryptionKeySource::CreateFromHexStrings(
|
||||
FLAGS_key_id, FLAGS_key, FLAGS_pssh, "");
|
||||
|
|
|
@ -169,11 +169,12 @@ void AesCbcEncryptor::Encrypt(const std::string& plaintext,
|
|||
padded_text.append(num_padding_bytes, static_cast<char>(num_padding_bytes));
|
||||
|
||||
ciphertext->resize(padded_text.size());
|
||||
std::vector<uint8> iv(iv_);
|
||||
AES_cbc_encrypt(reinterpret_cast<const uint8*>(padded_text.data()),
|
||||
reinterpret_cast<uint8*>(string_as_array(ciphertext)),
|
||||
padded_text.size(),
|
||||
encrypt_key_.get(),
|
||||
&iv_[0],
|
||||
&iv[0],
|
||||
AES_ENCRYPT);
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ Status EncryptionKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
|
|||
return Status::OK;
|
||||
}
|
||||
|
||||
Status EncryptionKeySource::GetCryptoPeriodKey(size_t crypto_period_index,
|
||||
Status EncryptionKeySource::GetCryptoPeriodKey(uint32 crypto_period_index,
|
||||
TrackType track_type,
|
||||
EncryptionKey* key) {
|
||||
NOTIMPLEMENTED();
|
||||
|
|
|
@ -43,7 +43,7 @@ class EncryptionKeySource {
|
|||
|
||||
/// Get encryption key of the specified track type at the specified index.
|
||||
/// @return OK on success, an error status otherwise.
|
||||
virtual Status GetCryptoPeriodKey(size_t crypto_period_index,
|
||||
virtual Status GetCryptoPeriodKey(uint32 crypto_period_index,
|
||||
TrackType track_type,
|
||||
EncryptionKey* key);
|
||||
|
||||
|
|
|
@ -17,16 +17,26 @@
|
|||
|
||||
namespace media {
|
||||
|
||||
static const size_t kUnlimitedCapacity = 0u;
|
||||
static const int64 kInfiniteTimeout = -1;
|
||||
|
||||
/// A thread safe producer consumer queue implementation. It allows the standard
|
||||
/// push and pop operations. It also maintains a monotonically-increasing
|
||||
/// element position and allows peeking at the element at certain position.
|
||||
template <class T>
|
||||
class ProducerConsumerQueue {
|
||||
public:
|
||||
/// Create a ProducerConsumerQueue.
|
||||
/// Create a ProducerConsumerQueue starting from position 0.
|
||||
/// @param capacity is the maximum number of elements that the queue can hold
|
||||
/// at once. A value of zero means unlimited capacity.
|
||||
explicit ProducerConsumerQueue(size_t capacity);
|
||||
|
||||
/// Create a ProducerConsumerQueue starting from indicated position.
|
||||
/// @param capacity is the maximum number of elements that the queue can hold
|
||||
/// at once. A value of zero means unlimited capacity.
|
||||
/// @param starting_pos is the starting head position.
|
||||
ProducerConsumerQueue(size_t capacity, size_t starting_pos);
|
||||
|
||||
~ProducerConsumerQueue();
|
||||
|
||||
/// Push an element to the back of the queue. If the queue has reached its
|
||||
|
@ -89,14 +99,14 @@ class ProducerConsumerQueue {
|
|||
/// returned value may be meaningless if the queue is empty.
|
||||
size_t HeadPos() const {
|
||||
base::AutoLock l(lock_);
|
||||
return head_;
|
||||
return head_pos_;
|
||||
}
|
||||
|
||||
/// @return The position of the tail element in the queue. Note that the
|
||||
/// returned value may be meaningless if the queue is empty.
|
||||
size_t TailPos() const {
|
||||
base::AutoLock l(lock_);
|
||||
return head_ + q_.size() - 1;
|
||||
return head_pos_ + q_.size() - 1;
|
||||
}
|
||||
|
||||
/// @return true if the queue has been stopped using Stop(). This allows
|
||||
|
@ -107,12 +117,12 @@ class ProducerConsumerQueue {
|
|||
}
|
||||
|
||||
private:
|
||||
// Move head_ to center on pos.
|
||||
// Move head_pos_ to center on pos.
|
||||
void SlideHeadOnCenter(size_t pos);
|
||||
|
||||
const size_t capacity_; // Maximum number of elements; zero means unlimited.
|
||||
mutable base::Lock lock_; // Lock protecting all other variables below.
|
||||
size_t head_; // Head position.
|
||||
size_t head_pos_; // Head position.
|
||||
std::deque<T> q_; // Internal queue holding the elements.
|
||||
base::ConditionVariable not_empty_cv_;
|
||||
base::ConditionVariable not_full_cv_;
|
||||
|
@ -126,12 +136,23 @@ class ProducerConsumerQueue {
|
|||
template <class T>
|
||||
ProducerConsumerQueue<T>::ProducerConsumerQueue(size_t capacity)
|
||||
: capacity_(capacity),
|
||||
head_(0),
|
||||
head_pos_(0),
|
||||
not_empty_cv_(&lock_),
|
||||
not_full_cv_(&lock_),
|
||||
new_element_cv_(&lock_),
|
||||
stop_requested_(false) {}
|
||||
|
||||
template <class T>
|
||||
ProducerConsumerQueue<T>::ProducerConsumerQueue(size_t capacity,
|
||||
size_t starting_pos)
|
||||
: capacity_(capacity),
|
||||
head_pos_(starting_pos),
|
||||
not_empty_cv_(&lock_),
|
||||
not_full_cv_(&lock_),
|
||||
new_element_cv_(&lock_),
|
||||
stop_requested_(false) {
|
||||
}
|
||||
|
||||
template <class T>
|
||||
ProducerConsumerQueue<T>::~ProducerConsumerQueue() {}
|
||||
|
||||
|
@ -218,7 +239,7 @@ Status ProducerConsumerQueue<T>::Pop(T* element, int64 timeout_ms) {
|
|||
|
||||
*element = q_.front();
|
||||
q_.pop_front();
|
||||
++head_;
|
||||
++head_pos_;
|
||||
|
||||
// Signal other consumers if we have more elements.
|
||||
if (woken && !q_.empty())
|
||||
|
@ -231,10 +252,11 @@ Status ProducerConsumerQueue<T>::Peek(size_t pos,
|
|||
T* element,
|
||||
int64 timeout_ms) {
|
||||
base::AutoLock l(lock_);
|
||||
if (pos < head_) {
|
||||
if (pos < head_pos_) {
|
||||
return Status(
|
||||
error::INVALID_ARGUMENT,
|
||||
base::StringPrintf("pos (%zu) is too small; head is %zu.", pos, head_));
|
||||
base::StringPrintf(
|
||||
"pos (%zu) is too small; head is at %zu.", pos, head_pos_));
|
||||
}
|
||||
|
||||
bool woken = false;
|
||||
|
@ -242,10 +264,10 @@ Status ProducerConsumerQueue<T>::Peek(size_t pos,
|
|||
base::ElapsedTimer timer;
|
||||
base::TimeDelta timeout_delta = base::TimeDelta::FromMilliseconds(timeout_ms);
|
||||
|
||||
// Move head_ to create some space (move the sliding window centered @ pos).
|
||||
// Move head to create some space (move the sliding window centered @ pos).
|
||||
SlideHeadOnCenter(pos);
|
||||
|
||||
while (pos >= head_ + q_.size()) {
|
||||
while (pos >= head_pos_ + q_.size()) {
|
||||
if (stop_requested_)
|
||||
return Status(error::STOPPED, "");
|
||||
|
||||
|
@ -262,12 +284,12 @@ Status ProducerConsumerQueue<T>::Peek(size_t pos,
|
|||
return Status(error::TIME_OUT, "Time out on peeking.");
|
||||
}
|
||||
}
|
||||
// Move head_ to create some space (move the sliding window centered @ pos).
|
||||
// Move head to create some space (move the sliding window centered @ pos).
|
||||
SlideHeadOnCenter(pos);
|
||||
woken = true;
|
||||
}
|
||||
|
||||
*element = q_[pos - head_];
|
||||
*element = q_[pos - head_pos_];
|
||||
|
||||
// Signal other consumers if we have more elements.
|
||||
if (woken && !q_.empty())
|
||||
|
@ -281,11 +303,11 @@ void ProducerConsumerQueue<T>::SlideHeadOnCenter(size_t pos) {
|
|||
|
||||
if (capacity_) {
|
||||
// Signal producer to proceed if we are going to create some capacity.
|
||||
if (q_.size() == capacity_ && pos > head_ + capacity_ / 2)
|
||||
if (q_.size() == capacity_ && pos > head_pos_ + capacity_ / 2)
|
||||
not_full_cv_.Signal();
|
||||
|
||||
while (!q_.empty() && pos > head_ + capacity_ / 2) {
|
||||
++head_;
|
||||
while (!q_.empty() && pos > head_pos_ + capacity_ / 2) {
|
||||
++head_pos_;
|
||||
q_.pop_front();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,10 +13,8 @@
|
|||
#include "testing/gtest/include/gtest/gtest.h"
|
||||
|
||||
namespace {
|
||||
const size_t kUnlimitedCapacity = 0u;
|
||||
const size_t kCapacity = 10u;
|
||||
const int64 kTimeout = 100; // 0.1s.
|
||||
const int64 kInfiniteTimeout = -1;
|
||||
|
||||
// Check that the |delta| is approximately |time_in_milliseconds|.
|
||||
bool CheckTimeApproxEqual(int64 time_in_milliseconds,
|
||||
|
|
|
@ -7,8 +7,10 @@
|
|||
#include "media/base/widevine_encryption_key_source.h"
|
||||
|
||||
#include "base/base64.h"
|
||||
#include "base/bind.h"
|
||||
#include "base/json/json_reader.h"
|
||||
#include "base/json/json_writer.h"
|
||||
#include "base/memory/ref_counted.h"
|
||||
#include "base/stl_util.h"
|
||||
#include "base/values.h"
|
||||
#include "media/base/http_fetcher.h"
|
||||
|
@ -34,6 +36,11 @@ const char kLicenseStatusTransientError[] = "INTERNAL_ERROR";
|
|||
const int kNumTransientErrorRetries = 5;
|
||||
const int kFirstRetryDelayMilliseconds = 1000;
|
||||
|
||||
// Default crypto period count, which is the number of keys to fetch on every
|
||||
// key rotation enabled request.
|
||||
const int kDefaultCryptoPeriodCount = 10;
|
||||
const int kGetKeyTimeoutInSeconds = 5 * 60; // 5 minutes.
|
||||
|
||||
bool Base64StringToBytes(const std::string& base64_string,
|
||||
std::vector<uint8>* bytes) {
|
||||
DCHECK(bytes);
|
||||
|
@ -93,42 +100,68 @@ bool GetPsshData(const base::DictionaryValue& track_dict,
|
|||
|
||||
namespace media {
|
||||
|
||||
// A ref counted wrapper for EncryptionKeyMap.
|
||||
class WidevineEncryptionKeySource::RefCountedEncryptionKeyMap
|
||||
: public base::RefCountedThreadSafe<RefCountedEncryptionKeyMap> {
|
||||
public:
|
||||
explicit RefCountedEncryptionKeyMap(EncryptionKeyMap* encryption_key_map) {
|
||||
DCHECK(encryption_key_map);
|
||||
encryption_key_map_.swap(*encryption_key_map);
|
||||
}
|
||||
|
||||
std::map<EncryptionKeySource::TrackType, EncryptionKey*>& map() {
|
||||
return encryption_key_map_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class base::RefCountedThreadSafe<RefCountedEncryptionKeyMap>;
|
||||
|
||||
~RefCountedEncryptionKeyMap() { STLDeleteValues(&encryption_key_map_); }
|
||||
|
||||
EncryptionKeyMap encryption_key_map_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(RefCountedEncryptionKeyMap);
|
||||
};
|
||||
|
||||
WidevineEncryptionKeySource::WidevineEncryptionKeySource(
|
||||
const std::string& server_url,
|
||||
const std::string& content_id,
|
||||
scoped_ptr<RequestSigner> signer)
|
||||
scoped_ptr<RequestSigner> signer,
|
||||
int first_crypto_period_index)
|
||||
: http_fetcher_(new SimpleHttpFetcher()),
|
||||
server_url_(server_url),
|
||||
content_id_(content_id),
|
||||
signer_(signer.Pass()),
|
||||
key_fetched_(false) {
|
||||
key_rotation_enabled_(first_crypto_period_index >= 0),
|
||||
crypto_period_count_(kDefaultCryptoPeriodCount),
|
||||
first_crypto_period_index_(first_crypto_period_index),
|
||||
key_production_thread_(
|
||||
"KeyProductionThread",
|
||||
base::Bind(&WidevineEncryptionKeySource::FetchKeysTask,
|
||||
base::Unretained(this))),
|
||||
key_pool_(kDefaultCryptoPeriodCount,
|
||||
key_rotation_enabled_ ? first_crypto_period_index : 0) {
|
||||
DCHECK(signer_);
|
||||
key_production_thread_.Start();
|
||||
}
|
||||
|
||||
WidevineEncryptionKeySource::~WidevineEncryptionKeySource() {
|
||||
STLDeleteValues(&encryption_key_map_);
|
||||
key_pool_.Stop();
|
||||
key_production_thread_.Join();
|
||||
}
|
||||
|
||||
Status WidevineEncryptionKeySource::GetKey(TrackType track_type,
|
||||
EncryptionKey* key) {
|
||||
DCHECK(track_type == TRACK_TYPE_SD || track_type == TRACK_TYPE_HD ||
|
||||
track_type == TRACK_TYPE_AUDIO);
|
||||
Status status;
|
||||
if (!key_fetched_) {
|
||||
base::AutoLock auto_lock(lock_);
|
||||
if (!key_fetched_) {
|
||||
status = FetchKeys();
|
||||
if (status.ok())
|
||||
key_fetched_ = true;
|
||||
}
|
||||
}
|
||||
if (!status.ok())
|
||||
return status;
|
||||
if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
|
||||
return Status(error::INTERNAL_ERROR,
|
||||
"Cannot find key of type " + TrackTypeToString(track_type));
|
||||
}
|
||||
*key = *encryption_key_map_[track_type];
|
||||
return Status::OK;
|
||||
DCHECK(!key_rotation_enabled_);
|
||||
return GetKeyInternal(0u, track_type, key);
|
||||
}
|
||||
|
||||
Status WidevineEncryptionKeySource::GetCryptoPeriodKey(
|
||||
uint32 crypto_period_index,
|
||||
TrackType track_type,
|
||||
EncryptionKey* key) {
|
||||
DCHECK(key_rotation_enabled_);
|
||||
return GetKeyInternal(crypto_period_index, track_type, key);
|
||||
}
|
||||
|
||||
void WidevineEncryptionKeySource::set_http_fetcher(
|
||||
|
@ -136,9 +169,50 @@ void WidevineEncryptionKeySource::set_http_fetcher(
|
|||
http_fetcher_ = http_fetcher.Pass();
|
||||
}
|
||||
|
||||
Status WidevineEncryptionKeySource::FetchKeys() {
|
||||
Status WidevineEncryptionKeySource::GetKeyInternal(
|
||||
uint32 crypto_period_index,
|
||||
TrackType track_type,
|
||||
EncryptionKey* key) {
|
||||
DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
|
||||
DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
|
||||
|
||||
scoped_refptr<RefCountedEncryptionKeyMap> ref_counted_encryption_key_map;
|
||||
Status status = key_pool_.Peek(crypto_period_index,
|
||||
&ref_counted_encryption_key_map,
|
||||
kGetKeyTimeoutInSeconds * 1000);
|
||||
if (!status.ok()) {
|
||||
if (status.error_code() == error::STOPPED) {
|
||||
CHECK(!common_encryption_request_status_.ok());
|
||||
return common_encryption_request_status_;
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
EncryptionKeyMap& encryption_key_map = ref_counted_encryption_key_map->map();
|
||||
if (encryption_key_map.find(track_type) == encryption_key_map.end()) {
|
||||
return Status(error::INTERNAL_ERROR,
|
||||
"Cannot find key of type " + TrackTypeToString(track_type));
|
||||
}
|
||||
*key = *encryption_key_map[track_type];
|
||||
return Status::OK;
|
||||
}
|
||||
|
||||
void WidevineEncryptionKeySource::FetchKeysTask() {
|
||||
Status status = FetchKeys(first_crypto_period_index_);
|
||||
if (key_rotation_enabled_) {
|
||||
while (status.ok()) {
|
||||
first_crypto_period_index_ += crypto_period_count_;
|
||||
status = FetchKeys(first_crypto_period_index_);
|
||||
}
|
||||
}
|
||||
common_encryption_request_status_ = status;
|
||||
key_pool_.Stop();
|
||||
}
|
||||
|
||||
Status WidevineEncryptionKeySource::FetchKeys(
|
||||
uint32 first_crypto_period_index) {
|
||||
std::string request;
|
||||
FillRequest(content_id_, &request);
|
||||
FillRequest(content_id_, first_crypto_period_index, &request);
|
||||
|
||||
std::string message;
|
||||
Status status = SignRequest(request, &message);
|
||||
|
@ -185,6 +259,7 @@ Status WidevineEncryptionKeySource::FetchKeys() {
|
|||
}
|
||||
|
||||
void WidevineEncryptionKeySource::FillRequest(const std::string& content_id,
|
||||
uint32 first_crypto_period_index,
|
||||
std::string* request) {
|
||||
DCHECK(request);
|
||||
|
||||
|
@ -215,6 +290,13 @@ void WidevineEncryptionKeySource::FillRequest(const std::string& content_id,
|
|||
drm_types->AppendString("WIDEVINE");
|
||||
request_dict.Set("drm_types", drm_types);
|
||||
|
||||
// Build key rotation fields.
|
||||
if (key_rotation_enabled_) {
|
||||
request_dict.SetInteger("first_crypto_period_index",
|
||||
first_crypto_period_index);
|
||||
request_dict.SetInteger("crypto_period_count", crypto_period_count_);
|
||||
}
|
||||
|
||||
base::JSONWriter::Write(&request_dict, request);
|
||||
}
|
||||
|
||||
|
@ -288,17 +370,40 @@ bool WidevineEncryptionKeySource::ExtractEncryptionKey(
|
|||
|
||||
const base::ListValue* tracks;
|
||||
RCHECK(license_dict->GetList("tracks", &tracks));
|
||||
RCHECK(tracks->GetSize() >= NUM_VALID_TRACK_TYPES);
|
||||
RCHECK(key_rotation_enabled_
|
||||
? tracks->GetSize() >= NUM_VALID_TRACK_TYPES * crypto_period_count_
|
||||
: tracks->GetSize() >= NUM_VALID_TRACK_TYPES);
|
||||
|
||||
int current_crypto_period_index = first_crypto_period_index_;
|
||||
|
||||
EncryptionKeyMap encryption_key_map;
|
||||
for (size_t i = 0; i < tracks->GetSize(); ++i) {
|
||||
const base::DictionaryValue* track_dict;
|
||||
RCHECK(tracks->GetDictionary(i, &track_dict));
|
||||
|
||||
if (key_rotation_enabled_) {
|
||||
int crypto_period_index;
|
||||
RCHECK(
|
||||
track_dict->GetInteger("crypto_period_index", &crypto_period_index));
|
||||
if (crypto_period_index != current_crypto_period_index) {
|
||||
if (crypto_period_index != current_crypto_period_index + 1) {
|
||||
LOG(ERROR) << "Expecting crypto period index "
|
||||
<< current_crypto_period_index << " or "
|
||||
<< current_crypto_period_index + 1 << "; Seen "
|
||||
<< crypto_period_index << " at track " << i;
|
||||
return false;
|
||||
}
|
||||
if (!PushToKeyPool(&encryption_key_map))
|
||||
return false;
|
||||
++current_crypto_period_index;
|
||||
}
|
||||
}
|
||||
|
||||
std::string track_type_str;
|
||||
RCHECK(track_dict->GetString("type", &track_type_str));
|
||||
TrackType track_type = GetTrackTypeFromString(track_type_str);
|
||||
DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
|
||||
RCHECK(encryption_key_map_.find(track_type) == encryption_key_map_.end());
|
||||
RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
|
||||
|
||||
scoped_ptr<EncryptionKey> encryption_key(new EncryptionKey());
|
||||
std::vector<uint8> pssh_data;
|
||||
|
@ -307,7 +412,24 @@ bool WidevineEncryptionKeySource::ExtractEncryptionKey(
|
|||
!GetPsshData(*track_dict, &pssh_data))
|
||||
return false;
|
||||
encryption_key->pssh = PsshBoxFromPsshData(pssh_data);
|
||||
encryption_key_map_[track_type] = encryption_key.release();
|
||||
encryption_key_map[track_type] = encryption_key.release();
|
||||
}
|
||||
|
||||
DCHECK(!encryption_key_map.empty());
|
||||
return PushToKeyPool(&encryption_key_map);
|
||||
}
|
||||
|
||||
bool WidevineEncryptionKeySource::PushToKeyPool(
|
||||
EncryptionKeyMap* encryption_key_map) {
|
||||
DCHECK(encryption_key_map);
|
||||
Status status =
|
||||
key_pool_.Push(scoped_refptr<RefCountedEncryptionKeyMap>(
|
||||
new RefCountedEncryptionKeyMap(encryption_key_map)),
|
||||
kInfiniteTimeout);
|
||||
encryption_key_map->clear();
|
||||
if (!status.ok()) {
|
||||
DCHECK_EQ(error::STOPPED, status.error_code());
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
@ -11,10 +11,13 @@
|
|||
|
||||
#include "base/basictypes.h"
|
||||
#include "base/memory/scoped_ptr.h"
|
||||
#include "base/synchronization/lock.h"
|
||||
#include "media/base/closure_thread.h"
|
||||
#include "media/base/encryption_key_source.h"
|
||||
#include "media/base/producer_consumer_queue.h"
|
||||
|
||||
namespace media {
|
||||
/// A negative crypto period index disables key rotation.
|
||||
static const int kDisableKeyRotation = -1;
|
||||
|
||||
class HttpFetcher;
|
||||
class RequestSigner;
|
||||
|
@ -25,26 +28,47 @@ class WidevineEncryptionKeySource : public EncryptionKeySource {
|
|||
public:
|
||||
/// @param server_url is the Widevine common encryption server url.
|
||||
/// @param content_id the unique id identify the content to be encrypted.
|
||||
/// @param signer must not be NULL.
|
||||
/// @param signer signs the request message. It should not be NULL.
|
||||
/// @param first_crypto_period_index indicates the starting crypto period
|
||||
/// index. Set it to kDisableKeyRotation to disable key rotation.
|
||||
WidevineEncryptionKeySource(const std::string& server_url,
|
||||
const std::string& content_id,
|
||||
scoped_ptr<RequestSigner> signer);
|
||||
scoped_ptr<RequestSigner> signer,
|
||||
int first_crypto_period_index);
|
||||
virtual ~WidevineEncryptionKeySource();
|
||||
|
||||
/// EncryptionKeySource implementation override.
|
||||
/// @name EncryptionKeySource implementation overrides.
|
||||
/// @{
|
||||
virtual Status GetKey(TrackType track_type, EncryptionKey* key) OVERRIDE;
|
||||
virtual Status GetCryptoPeriodKey(uint32 crypto_period_index,
|
||||
TrackType track_type,
|
||||
EncryptionKey* key) OVERRIDE;
|
||||
/// @}
|
||||
|
||||
/// Inject an @b HttpFetcher object, mainly used for testing.
|
||||
/// @param http_fetcher points to the @b HttpFetcher object to be injected.
|
||||
void set_http_fetcher(scoped_ptr<HttpFetcher> http_fetcher);
|
||||
|
||||
private:
|
||||
typedef std::map<TrackType, EncryptionKey*> EncryptionKeyMap;
|
||||
class RefCountedEncryptionKeyMap;
|
||||
|
||||
// Internal routine for getting keys.
|
||||
Status GetKeyInternal(uint32 crypto_period_index,
|
||||
TrackType track_type,
|
||||
EncryptionKey* key);
|
||||
|
||||
// The closure task to fetch keys repeatedly.
|
||||
void FetchKeysTask();
|
||||
|
||||
// Fetch keys from server.
|
||||
Status FetchKeys();
|
||||
Status FetchKeys(uint32 first_crypto_period_index);
|
||||
|
||||
// Fill |request| with necessary fields for Widevine encryption request.
|
||||
// |request| should not be NULL.
|
||||
void FillRequest(const std::string& content_id, std::string* request);
|
||||
void FillRequest(const std::string& content_id,
|
||||
uint32 first_crypto_period_index,
|
||||
std::string* request);
|
||||
// Sign and properly format |request|.
|
||||
// |signed_request| should not be NULL. Return OK on success.
|
||||
Status SignRequest(const std::string& request, std::string* signed_request);
|
||||
|
@ -56,6 +80,8 @@ class WidevineEncryptionKeySource : public EncryptionKeySource {
|
|||
// failure is because of a transient error from the server. |transient_error|
|
||||
// should not be NULL.
|
||||
bool ExtractEncryptionKey(const std::string& response, bool* transient_error);
|
||||
// Push the keys to the key pool.
|
||||
bool PushToKeyPool(EncryptionKeyMap* encryption_key_map);
|
||||
|
||||
// The fetcher object used to fetch HTTP response from server.
|
||||
// It is initialized to a default fetcher on class initialization.
|
||||
|
@ -65,9 +91,12 @@ class WidevineEncryptionKeySource : public EncryptionKeySource {
|
|||
std::string content_id_;
|
||||
scoped_ptr<RequestSigner> signer_;
|
||||
|
||||
mutable base::Lock lock_;
|
||||
bool key_fetched_; // Protected by lock_;
|
||||
std::map<TrackType, EncryptionKey*> encryption_key_map_;
|
||||
const bool key_rotation_enabled_;
|
||||
const uint32 crypto_period_count_;
|
||||
uint32 first_crypto_period_index_;
|
||||
ClosureThread key_production_thread_;
|
||||
ProducerConsumerQueue<scoped_refptr<RefCountedEncryptionKeyMap> > key_pool_;
|
||||
Status common_encryption_request_status_;
|
||||
|
||||
DISALLOW_COPY_AND_ASSIGN(WidevineEncryptionKeySource);
|
||||
};
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "media/base/widevine_encryption_key_source.h"
|
||||
|
||||
#include "base/base64.h"
|
||||
#include "base/strings/string_number_conversions.h"
|
||||
#include "base/strings/stringprintf.h"
|
||||
#include "media/base/http_fetcher.h"
|
||||
#include "media/base/request_signer.h"
|
||||
|
@ -89,6 +90,7 @@ std::string GetPsshDataFromPsshBox(const std::string& pssh_box) {
|
|||
|
||||
using ::testing::_;
|
||||
using ::testing::DoAll;
|
||||
using ::testing::InSequence;
|
||||
using ::testing::Return;
|
||||
using ::testing::SetArgPointee;
|
||||
|
||||
|
@ -129,9 +131,12 @@ class WidevineEncryptionKeySourceTest : public ::testing::Test {
|
|||
mock_http_fetcher_(new MockHttpFetcher()) {}
|
||||
|
||||
protected:
|
||||
void CreateWidevineEncryptionKeySource() {
|
||||
void CreateWidevineEncryptionKeySource(int first_crypto_period_index) {
|
||||
widevine_encryption_key_source_.reset(new WidevineEncryptionKeySource(
|
||||
kServerUrl, kContentId, mock_request_signer_.PassAs<RequestSigner>()));
|
||||
kServerUrl,
|
||||
kContentId,
|
||||
mock_request_signer_.PassAs<RequestSigner>(),
|
||||
first_crypto_period_index));
|
||||
widevine_encryption_key_source_->set_http_fetcher(
|
||||
mock_http_fetcher_.PassAs<HttpFetcher>());
|
||||
}
|
||||
|
@ -155,11 +160,11 @@ TEST_F(WidevineEncryptionKeySourceTest, GetTrackTypeFromString) {
|
|||
EncryptionKeySource::GetTrackTypeFromString("FOO"));
|
||||
}
|
||||
|
||||
TEST_F(WidevineEncryptionKeySourceTest, GeneratureSignatureFailure) {
|
||||
TEST_F(WidevineEncryptionKeySourceTest, GenerateSignatureFailure) {
|
||||
EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _))
|
||||
.WillOnce(Return(false));
|
||||
|
||||
CreateWidevineEncryptionKeySource();
|
||||
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
|
||||
EncryptionKey encryption_key;
|
||||
ASSERT_EQ(Status(error::INTERNAL_ERROR, "Signature generation failed."),
|
||||
widevine_encryption_key_source_->GetKey(
|
||||
|
@ -183,7 +188,7 @@ TEST_F(WidevineEncryptionKeySourceTest, HttpPostFailure) {
|
|||
EXPECT_CALL(*mock_http_fetcher_, Post(kServerUrl, expected_post_data, _))
|
||||
.WillOnce(Return(kMockStatus));
|
||||
|
||||
CreateWidevineEncryptionKeySource();
|
||||
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
|
||||
EncryptionKey encryption_key;
|
||||
ASSERT_EQ(kMockStatus,
|
||||
widevine_encryption_key_source_->GetKey(
|
||||
|
@ -194,13 +199,13 @@ TEST_F(WidevineEncryptionKeySourceTest, LicenseStatusOK) {
|
|||
EXPECT_CALL(*mock_request_signer_, GenerateSignature(_, _))
|
||||
.WillOnce(Return(true));
|
||||
|
||||
std::string expected_response = base::StringPrintf(
|
||||
std::string mock_response = base::StringPrintf(
|
||||
kHttpResponseFormat, Base64Encode(GenerateMockLicenseResponse()).c_str());
|
||||
|
||||
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
|
||||
.WillOnce(DoAll(SetArgPointee<2>(expected_response), Return(Status::OK)));
|
||||
.WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)));
|
||||
|
||||
CreateWidevineEncryptionKeySource();
|
||||
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
|
||||
|
||||
EncryptionKey encryption_key;
|
||||
const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"};
|
||||
|
@ -221,7 +226,7 @@ TEST_F(WidevineEncryptionKeySourceTest, RetryOnTransientError) {
|
|||
|
||||
std::string mock_license_status = base::StringPrintf(
|
||||
kLicenseResponseFormat, kLicenseStatusTransientError, "");
|
||||
std::string expected_response = base::StringPrintf(
|
||||
std::string mock_response = base::StringPrintf(
|
||||
kHttpResponseFormat, Base64Encode(mock_license_status).c_str());
|
||||
|
||||
std::string expected_retried_response = base::StringPrintf(
|
||||
|
@ -229,11 +234,11 @@ TEST_F(WidevineEncryptionKeySourceTest, RetryOnTransientError) {
|
|||
|
||||
// Retry is expected on transient error.
|
||||
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
|
||||
.WillOnce(DoAll(SetArgPointee<2>(expected_response), Return(Status::OK)))
|
||||
.WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)))
|
||||
.WillOnce(DoAll(SetArgPointee<2>(expected_retried_response),
|
||||
Return(Status::OK)));
|
||||
|
||||
CreateWidevineEncryptionKeySource();
|
||||
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
|
||||
EncryptionKey encryption_key;
|
||||
ASSERT_OK(widevine_encryption_key_source_->GetKey(
|
||||
EncryptionKeySource::TRACK_TYPE_SD, &encryption_key));
|
||||
|
@ -255,13 +260,131 @@ TEST_F(WidevineEncryptionKeySourceTest, NoRetryOnUnknownError) {
|
|||
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
|
||||
.WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)));
|
||||
|
||||
CreateWidevineEncryptionKeySource();
|
||||
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
|
||||
EncryptionKey encryption_key;
|
||||
ASSERT_EQ(
|
||||
error::SERVER_ERROR,
|
||||
widevine_encryption_key_source_->GetKey(
|
||||
Status status = widevine_encryption_key_source_->GetKey(
|
||||
EncryptionKeySource::TRACK_TYPE_SD, &encryption_key);
|
||||
ASSERT_EQ(error::SERVER_ERROR, status.error_code());
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
const char kCryptoPeriodRequestMessageFormat[] =
|
||||
"{\"content_id\":\"%s\",\"crypto_period_count\":%u,\"drm_types\":["
|
||||
"\"WIDEVINE\"],\"first_crypto_period_index\":%u,\"policy\":\"\","
|
||||
"\"tracks\":[{\"type\":\"SD\"},{\"type\":\"HD\"},{\"type\":\"AUDIO\"}]}";
|
||||
|
||||
const char kCryptoPeriodTrackFormat[] =
|
||||
"{\"type\":\"%s\",\"key_id\":\"\",\"key\":"
|
||||
"\"%s\",\"pssh\":[{\"drm_type\":\"WIDEVINE\",\"data\":\"\"}], "
|
||||
"\"crypto_period_index\":%u}";
|
||||
|
||||
std::string GetMockKey(const std::string& track_type, uint32 index) {
|
||||
return "MockKey" + track_type + "@" + base::UintToString(index);
|
||||
}
|
||||
|
||||
std::string GenerateMockLicenseResponse(uint32 initial_crypto_period_index,
|
||||
uint32 crypto_period_count) {
|
||||
const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"};
|
||||
std::string tracks;
|
||||
for (uint32 index = initial_crypto_period_index;
|
||||
index < initial_crypto_period_index + crypto_period_count;
|
||||
++index) {
|
||||
for (size_t i = 0; i < 3; ++i) {
|
||||
if (!tracks.empty())
|
||||
tracks += ",";
|
||||
tracks += base::StringPrintf(
|
||||
kCryptoPeriodTrackFormat,
|
||||
kTrackTypes[i].c_str(),
|
||||
Base64Encode(GetMockKey(kTrackTypes[i], index)).c_str(),
|
||||
index);
|
||||
}
|
||||
}
|
||||
return base::StringPrintf(kLicenseResponseFormat, "OK", tracks.c_str());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
TEST_F(WidevineEncryptionKeySourceTest, KeyRotationTest) {
|
||||
const uint32 kFirstCryptoPeriodIndex = 8;
|
||||
const uint32 kCryptoPeriodCount = 10;
|
||||
// Array of indexes to be checked.
|
||||
const uint32 kCryptoPeriodIndexes[] = {kFirstCryptoPeriodIndex, 17, 37,
|
||||
38, 36, 39};
|
||||
// Derived from kCryptoPeriodIndexes: ceiling((39 - 8 ) / 10).
|
||||
const uint32 kCryptoIterations = 4;
|
||||
|
||||
// Generate expectations in sequence.
|
||||
InSequence dummy;
|
||||
for (uint32 i = 0; i < kCryptoIterations; ++i) {
|
||||
uint32 first_crypto_period_index =
|
||||
kFirstCryptoPeriodIndex + i * kCryptoPeriodCount;
|
||||
std::string expected_message =
|
||||
base::StringPrintf(kCryptoPeriodRequestMessageFormat,
|
||||
Base64Encode(kContentId).c_str(),
|
||||
kCryptoPeriodCount,
|
||||
first_crypto_period_index);
|
||||
EXPECT_CALL(*mock_request_signer_, GenerateSignature(expected_message, _))
|
||||
.WillOnce(DoAll(SetArgPointee<1>(kMockSignature), Return(true)));
|
||||
|
||||
std::string mock_response = base::StringPrintf(
|
||||
kHttpResponseFormat,
|
||||
Base64Encode(GenerateMockLicenseResponse(first_crypto_period_index,
|
||||
kCryptoPeriodCount)).c_str());
|
||||
EXPECT_CALL(*mock_http_fetcher_, Post(_, _, _))
|
||||
.WillOnce(DoAll(SetArgPointee<2>(mock_response), Return(Status::OK)));
|
||||
}
|
||||
|
||||
CreateWidevineEncryptionKeySource(kFirstCryptoPeriodIndex);
|
||||
|
||||
EncryptionKey encryption_key;
|
||||
|
||||
// Index before kFirstCryptoPeriodIndex is invalid.
|
||||
Status status = widevine_encryption_key_source_->GetCryptoPeriodKey(
|
||||
kFirstCryptoPeriodIndex - 1,
|
||||
EncryptionKeySource::TRACK_TYPE_SD,
|
||||
&encryption_key).error_code());
|
||||
&encryption_key);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code());
|
||||
|
||||
for (size_t i = 0; i < arraysize(kCryptoPeriodIndexes); ++i) {
|
||||
const std::string kTrackTypes[] = {"SD", "HD", "AUDIO"};
|
||||
for (size_t j = 0; j < 3; ++j) {
|
||||
ASSERT_OK(widevine_encryption_key_source_->GetCryptoPeriodKey(
|
||||
kCryptoPeriodIndexes[i],
|
||||
EncryptionKeySource::GetTrackTypeFromString(kTrackTypes[j]),
|
||||
&encryption_key));
|
||||
EXPECT_EQ(GetMockKey(kTrackTypes[j], kCryptoPeriodIndexes[i]),
|
||||
ToString(encryption_key.key));
|
||||
}
|
||||
}
|
||||
|
||||
// The old crypto period indexes should have been garbage collected.
|
||||
status = widevine_encryption_key_source_->GetCryptoPeriodKey(
|
||||
kFirstCryptoPeriodIndex,
|
||||
EncryptionKeySource::TRACK_TYPE_SD,
|
||||
&encryption_key);
|
||||
EXPECT_EQ(error::INVALID_ARGUMENT, status.error_code());
|
||||
}
|
||||
|
||||
class WidevineEncryptionKeySourceDeathTest
|
||||
: public WidevineEncryptionKeySourceTest {};
|
||||
|
||||
TEST_F(WidevineEncryptionKeySourceDeathTest,
|
||||
GetCryptoPeriodKeyOnNonKeyRotationSource) {
|
||||
CreateWidevineEncryptionKeySource(kDisableKeyRotation);
|
||||
EncryptionKey encryption_key;
|
||||
EXPECT_DEBUG_DEATH(
|
||||
widevine_encryption_key_source_->GetCryptoPeriodKey(
|
||||
0, EncryptionKeySource::TRACK_TYPE_SD, &encryption_key),
|
||||
"");
|
||||
}
|
||||
|
||||
TEST_F(WidevineEncryptionKeySourceDeathTest, GetKeyOnKeyRotationSource) {
|
||||
CreateWidevineEncryptionKeySource(0);
|
||||
EncryptionKey encryption_key;
|
||||
EXPECT_DEBUG_DEATH(widevine_encryption_key_source_->GetKey(
|
||||
EncryptionKeySource::TRACK_TYPE_SD, &encryption_key),
|
||||
"");
|
||||
}
|
||||
|
||||
} // namespace media
|
||||
|
|
Loading…
Reference in New Issue