DASH Media Packaging SDK
 All Classes Namespaces Functions Variables Typedefs Enumerator
widevine_key_source.cc
1 // Copyright 2014 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/widevine_key_source.h"
8 
9 #include "packager/base/base64.h"
10 #include "packager/base/bind.h"
11 #include "packager/base/json/json_reader.h"
12 #include "packager/base/json/json_writer.h"
13 #include "packager/base/memory/ref_counted.h"
14 #include "packager/base/stl_util.h"
15 #include "packager/media/base/http_key_fetcher.h"
16 #include "packager/media/base/producer_consumer_queue.h"
17 #include "packager/media/base/request_signer.h"
18 
// Evaluates |x| once; if it is false, logs the stringified expression at ERROR
// level and makes the enclosing function return false. Only usable inside
// functions returning bool.
#define RCHECK(x)                                                \
  do {                                                           \
    const bool rcheck_passed_ = static_cast<bool>(x);            \
    if (!rcheck_passed_) {                                       \
      LOG(ERROR) << "Failure while processing: " << #x;          \
      return false;                                              \
    }                                                            \
  } while (0)
26 
27 namespace edash_packager {
28 namespace {
29 
// Passed (or negated) as the |enable_key_rotation| argument of
// FetchKeysInternal().
const bool kEnableKeyRotation = true;

// License response status values.
const char kLicenseStatusOK[] = "OK";
// The server may return INTERNAL_ERROR intermittently. It is a transient
// error, and a later request from the client may succeed.
const char kLicenseStatusTransientError[] = "INTERNAL_ERROR";

// Total number of request attempts (including the first one) made when the
// server reports a transient error or the fetch times out.
const int kNumTransientErrorRetries = 5;
// Delay before the second attempt; doubled after every further attempt.
const int kFirstRetryDelayMilliseconds = 1000;

// Number of crypto periods (keys per track) requested by each key
// rotation enabled request.
const int kDefaultCryptoPeriodCount = 10;

// Upper bound on how long GetKeyInternal() waits for a crypto period's keys.
const int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Timeout handed to the default HttpKeyFetcher.
const int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
47 
48 bool Base64StringToBytes(const std::string& base64_string,
49  std::vector<uint8_t>* bytes) {
50  DCHECK(bytes);
51  std::string str;
52  if (!base::Base64Decode(base64_string, &str))
53  return false;
54  bytes->assign(str.begin(), str.end());
55  return true;
56 }
57 
58 void BytesToBase64String(const std::vector<uint8_t>& bytes,
59  std::string* base64_string) {
60  DCHECK(base64_string);
61  base::Base64Encode(base::StringPiece(reinterpret_cast<const char*>
62  (bytes.data()), bytes.size()),
63  base64_string);
64 }
65 
// Reads the base64 "key" string from a license-response track dictionary and
// decodes it into |key|. Returns false (after RCHECK logs the failing
// expression) if the field is absent or is not valid base64.
bool GetKeyFromTrack(const base::DictionaryValue& track_dict,
                     std::vector<uint8_t>* key) {
  DCHECK(key);
  std::string key_base64_string;
  RCHECK(track_dict.GetString("key", &key_base64_string));
  // NOTE(review): this logs the content key itself at verbosity 2; confirm
  // that is acceptable for the deployments that enable verbose logging.
  VLOG(2) << "Key:" << key_base64_string;
  RCHECK(Base64StringToBytes(key_base64_string, key));
  return true;
}
75 
// Reads the base64 "key_id" string from a license-response track dictionary
// and decodes it into |key_id|. Returns false (after RCHECK logs the failing
// expression) if the field is absent or is not valid base64.
bool GetKeyIdFromTrack(const base::DictionaryValue& track_dict,
                       std::vector<uint8_t>* key_id) {
  DCHECK(key_id);
  std::string key_id_base64_string;
  RCHECK(track_dict.GetString("key_id", &key_id_base64_string));
  VLOG(2) << "Keyid:" << key_id_base64_string;
  RCHECK(Base64StringToBytes(key_id_base64_string, key_id));
  return true;
}
85 
// Extracts the Widevine PSSH data from a license-response track dictionary.
// Only the first entry of the track's "pssh" list is used. It must be a
// dictionary whose "drm_type" is "WIDEVINE", and its base64 "data" field is
// decoded into |pssh_data|. Returns false on any missing field, on a
// non-Widevine drm_type, or on invalid base64.
bool GetPsshDataFromTrack(const base::DictionaryValue& track_dict,
                          std::vector<uint8_t>* pssh_data) {
  DCHECK(pssh_data);

  const base::ListValue* pssh_list;
  RCHECK(track_dict.GetList("pssh", &pssh_list));
  // Invariant check. We don't want to crash in release mode if possible.
  // The following code handles it gracefully if GetSize() does not return 1.
  // (An empty list makes GetDictionary(0) fail; extra entries are ignored.)
  DCHECK_EQ(1u, pssh_list->GetSize());

  const base::DictionaryValue* pssh_dict;
  RCHECK(pssh_list->GetDictionary(0, &pssh_dict));
  std::string drm_type;
  RCHECK(pssh_dict->GetString("drm_type", &drm_type));
  // The request only asks for WIDEVINE drm_types, so anything else is an
  // unexpected server response.
  if (drm_type != "WIDEVINE") {
    LOG(ERROR) << "Expecting drm_type 'WIDEVINE', get '" << drm_type << "'.";
    return false;
  }
  std::string pssh_data_base64_string;
  RCHECK(pssh_dict->GetString("data", &pssh_data_base64_string));

  VLOG(2) << "Pssh Data:" << pssh_data_base64_string;
  RCHECK(Base64StringToBytes(pssh_data_base64_string, pssh_data));
  return true;
}
111 
112 } // namespace
113 
114 namespace media {
115 
116 // A ref counted wrapper for EncryptionKeyMap.
117 class WidevineKeySource::RefCountedEncryptionKeyMap
118  : public base::RefCountedThreadSafe<RefCountedEncryptionKeyMap> {
119  public:
120  explicit RefCountedEncryptionKeyMap(EncryptionKeyMap* encryption_key_map) {
121  DCHECK(encryption_key_map);
122  encryption_key_map_.swap(*encryption_key_map);
123  }
124 
125  std::map<KeySource::TrackType, EncryptionKey*>& map() {
126  return encryption_key_map_;
127  }
128 
129  private:
130  friend class base::RefCountedThreadSafe<RefCountedEncryptionKeyMap>;
131 
132  ~RefCountedEncryptionKeyMap() { STLDeleteValues(&encryption_key_map_); }
133 
134  EncryptionKeyMap encryption_key_map_;
135 
136  DISALLOW_COPY_AND_ASSIGN(RefCountedEncryptionKeyMap);
137 };
138 
// Creates a key source that requests keys from |server_url| through an
// HttpKeyFetcher (timeout kKeyFetchTimeoutInSeconds; replaceable with
// set_key_fetcher()). The key production thread is started here. It runs
// FetchKeysTask(), which blocks on |start_key_production_| until
// GetCryptoPeriodKey() (or the destructor) signals it.
WidevineKeySource::WidevineKeySource(const std::string& server_url)
    : key_production_thread_("KeyProductionThread",
                             base::Bind(&WidevineKeySource::FetchKeysTask,
                                        base::Unretained(this))),
      key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
      server_url_(server_url),
      crypto_period_count_(kDefaultCryptoPeriodCount),
      key_production_started_(false),
      // NOTE(review): presumably (manual_reset=false,
      // initially_signaled=false); confirm against base::WaitableEvent.
      start_key_production_(false, false),
      first_crypto_period_index_(0) {
  key_production_thread_.Start();
}
151 
152 WidevineKeySource::~WidevineKeySource() {
153  if (key_pool_)
154  key_pool_->Stop();
155  if (key_production_thread_.HasBeenStarted()) {
156  // Signal the production thread to start key production if it is not
157  // signaled yet so the thread can be joined.
158  start_key_production_.Signal();
159  key_production_thread_.Join();
160  }
161  STLDeleteValues(&encryption_key_map_);
162 }
163 
164 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
165  const std::string& policy) {
166  base::AutoLock scoped_lock(lock_);
167  request_dict_.Clear();
168  std::string content_id_base64_string;
169  BytesToBase64String(content_id, &content_id_base64_string);
170  request_dict_.SetString("content_id", content_id_base64_string);
171  request_dict_.SetString("policy", policy);
172  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
173 }
174 
175 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& pssh_data) {
176  base::AutoLock scoped_lock(lock_);
177  request_dict_.Clear();
178  std::string pssh_data_base64_string;
179  BytesToBase64String(pssh_data, &pssh_data_base64_string);
180  request_dict_.SetString("pssh_data", pssh_data_base64_string);
181  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
182 }
183 
  // NOTE(review): the signature line of this function (original listing line
  // 184) is absent from this view; it is presumably
  // `Status WidevineKeySource::FetchKeys(uint32_t asset_id) {` -- confirm.
  // Builds a request identified by |asset_id| (no key rotation) and fetches
  // keys with widevine_classic == true, so only key values are extracted
  // from the response (no key IDs or PSSH data).
  base::AutoLock scoped_lock(lock_);
  request_dict_.Clear();
  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
  // instead as 32-bit integer can be lossless represented using double.
  request_dict_.SetDouble("asset_id", asset_id);
  return FetchKeysInternal(!kEnableKeyRotation, 0, true);
}
192 
193 Status WidevineKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
194  DCHECK(key);
195  if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
196  return Status(error::INTERNAL_ERROR,
197  "Cannot find key of type " + TrackTypeToString(track_type));
198  }
199  *key = *encryption_key_map_[track_type];
200  return Status::OK;
201 }
202 
203 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
204  EncryptionKey* key) {
205  DCHECK(key);
206  for (std::map<TrackType, EncryptionKey*>::iterator iter =
207  encryption_key_map_.begin();
208  iter != encryption_key_map_.end();
209  ++iter) {
210  if (iter->second->key_id == key_id) {
211  *key = *iter->second;
212  return Status::OK;
213  }
214  }
215  return Status(error::INTERNAL_ERROR,
216  "Cannot find key with specified key ID");
217 }
218 
219 Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
220  TrackType track_type,
221  EncryptionKey* key) {
222  DCHECK(key_production_thread_.HasBeenStarted());
223  // TODO(kqyang): This is not elegant. Consider refactoring later.
224  {
225  base::AutoLock scoped_lock(lock_);
226  if (!key_production_started_) {
227  // Another client may have a slightly smaller starting crypto period
228  // index. Set the initial value to account for that.
229  first_crypto_period_index_ =
230  crypto_period_index ? crypto_period_index - 1 : 0;
231  DCHECK(!key_pool_);
232  key_pool_.reset(new EncryptionKeyQueue(crypto_period_count_,
233  first_crypto_period_index_));
234  start_key_production_.Signal();
235  key_production_started_ = true;
236  }
237  }
238  return GetKeyInternal(crypto_period_index, track_type, key);
239 }
240 
241 std::string WidevineKeySource::UUID() {
242  return "edef8ba9-79d6-4ace-a3c8-27dcd51d21ed";
243 }
244 
245 void WidevineKeySource::set_signer(scoped_ptr<RequestSigner> signer) {
246  signer_ = signer.Pass();
247 }
248 
249 void WidevineKeySource::set_key_fetcher(scoped_ptr<KeyFetcher> key_fetcher) {
250  key_fetcher_ = key_fetcher.Pass();
251 }
252 
253 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
254  TrackType track_type,
255  EncryptionKey* key) {
256  DCHECK(key_pool_);
257  DCHECK(key);
258  DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
259  DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
260 
261  scoped_refptr<RefCountedEncryptionKeyMap> ref_counted_encryption_key_map;
262  Status status =
263  key_pool_->Peek(crypto_period_index, &ref_counted_encryption_key_map,
264  kGetKeyTimeoutInSeconds * 1000);
265  if (!status.ok()) {
266  if (status.error_code() == error::STOPPED) {
267  CHECK(!common_encryption_request_status_.ok());
268  return common_encryption_request_status_;
269  }
270  return status;
271  }
272 
273  EncryptionKeyMap& encryption_key_map = ref_counted_encryption_key_map->map();
274  if (encryption_key_map.find(track_type) == encryption_key_map.end()) {
275  return Status(error::INTERNAL_ERROR,
276  "Cannot find key of type " + TrackTypeToString(track_type));
277  }
278  *key = *encryption_key_map[track_type];
279  return Status::OK;
280 }
281 
282 void WidevineKeySource::FetchKeysTask() {
283  // Wait until key production is signaled.
284  start_key_production_.Wait();
285  if (!key_pool_ || key_pool_->Stopped())
286  return;
287 
288  Status status = FetchKeysInternal(kEnableKeyRotation,
289  first_crypto_period_index_,
290  false);
291  while (status.ok()) {
292  first_crypto_period_index_ += crypto_period_count_;
293  status = FetchKeysInternal(kEnableKeyRotation,
294  first_crypto_period_index_,
295  false);
296  }
297  common_encryption_request_status_ = status;
298  key_pool_->Stop();
299 }
300 
301 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
302  uint32_t first_crypto_period_index,
303  bool widevine_classic) {
304  std::string request;
305  FillRequest(enable_key_rotation,
306  first_crypto_period_index,
307  &request);
308 
309  std::string message;
310  Status status = GenerateKeyMessage(request, &message);
311  if (!status.ok())
312  return status;
313  VLOG(1) << "Message: " << message;
314 
315  std::string raw_response;
316  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
317 
318  // Perform client side retries if seeing server transient error to workaround
319  // server limitation.
320  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
321  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
322  if (status.ok()) {
323  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
324 
325  std::string response;
326  if (!DecodeResponse(raw_response, &response)) {
327  return Status(error::SERVER_ERROR,
328  "Failed to decode response '" + raw_response + "'.");
329  }
330 
331  bool transient_error = false;
332  if (ExtractEncryptionKey(enable_key_rotation,
333  widevine_classic,
334  response,
335  &transient_error))
336  return Status::OK;
337 
338  if (!transient_error) {
339  return Status(
340  error::SERVER_ERROR,
341  "Failed to extract encryption key from '" + response + "'.");
342  }
343  } else if (status.error_code() != error::TIME_OUT) {
344  return status;
345  }
346 
347  // Exponential backoff.
348  if (i != kNumTransientErrorRetries - 1) {
349  base::PlatformThread::Sleep(
350  base::TimeDelta::FromMilliseconds(sleep_duration));
351  sleep_duration *= 2;
352  }
353  }
354  return Status(error::SERVER_ERROR,
355  "Failed to recover from server internal error.");
356 }
357 
358 void WidevineKeySource::FillRequest(bool enable_key_rotation,
359  uint32_t first_crypto_period_index,
360  std::string* request) {
361  DCHECK(request);
362  DCHECK(!request_dict_.empty());
363 
364  // Build tracks.
365  base::ListValue* tracks = new base::ListValue();
366 
367  base::DictionaryValue* track_sd = new base::DictionaryValue();
368  track_sd->SetString("type", "SD");
369  tracks->Append(track_sd);
370  base::DictionaryValue* track_hd = new base::DictionaryValue();
371  track_hd->SetString("type", "HD");
372  tracks->Append(track_hd);
373  base::DictionaryValue* track_audio = new base::DictionaryValue();
374  track_audio->SetString("type", "AUDIO");
375  tracks->Append(track_audio);
376 
377  request_dict_.Set("tracks", tracks);
378 
379  // Build DRM types.
380  base::ListValue* drm_types = new base::ListValue();
381  drm_types->AppendString("WIDEVINE");
382  request_dict_.Set("drm_types", drm_types);
383 
384  // Build key rotation fields.
385  if (enable_key_rotation) {
386  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
387  // instead as 32-bit integer can be lossless represented using double.
388  request_dict_.SetDouble("first_crypto_period_index",
389  first_crypto_period_index);
390  request_dict_.SetInteger("crypto_period_count", crypto_period_count_);
391  }
392 
393  base::JSONWriter::WriteWithOptions(
394  request_dict_,
395  // Write doubles that have no fractional part as a normal integer, i.e.
396  // without using exponential notation or appending a '.0'.
397  base::JSONWriter::OPTIONS_OMIT_DOUBLE_TYPE_PRESERVATION, request);
398 }
399 
400 Status WidevineKeySource::GenerateKeyMessage(const std::string& request,
401  std::string* message) {
402  DCHECK(message);
403 
404  std::string request_base64_string;
405  base::Base64Encode(request, &request_base64_string);
406 
407  base::DictionaryValue request_dict;
408  request_dict.SetString("request", request_base64_string);
409 
410  // Sign the request.
411  if (signer_) {
412  std::string signature;
413  if (!signer_->GenerateSignature(request, &signature))
414  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
415 
416  std::string signature_base64_string;
417  base::Base64Encode(signature, &signature_base64_string);
418 
419  request_dict.SetString("signature", signature_base64_string);
420  request_dict.SetString("signer", signer_->signer_name());
421  }
422 
423  base::JSONWriter::Write(request_dict, message);
424  return Status::OK;
425 }
426 
// Unwraps the key server reply. |raw_response| must be a JSON object with a
// base64 string field "response". That field is decoded into |response|.
// Returns false (with an ERROR log) if any step fails.
bool WidevineKeySource::DecodeResponse(
    const std::string& raw_response,
    std::string* response) {
  DCHECK(response);

  // Extract base64 formatted response from JSON formatted raw response.
  scoped_ptr<base::Value> root(base::JSONReader::Read(raw_response));
  if (!root) {
    LOG(ERROR) << "'" << raw_response << "' is not in JSON format.";
    return false;
  }
  const base::DictionaryValue* response_dict = NULL;
  RCHECK(root->GetAsDictionary(&response_dict));

  std::string response_base64_string;
  RCHECK(response_dict->GetString("response", &response_base64_string));
  RCHECK(base::Base64Decode(response_base64_string, response));
  return true;
}
446 
447 bool WidevineKeySource::ExtractEncryptionKey(
448  bool enable_key_rotation,
449  bool widevine_classic,
450  const std::string& response,
451  bool* transient_error) {
452  DCHECK(transient_error);
453  *transient_error = false;
454 
455  scoped_ptr<base::Value> root(base::JSONReader::Read(response));
456  if (!root) {
457  LOG(ERROR) << "'" << response << "' is not in JSON format.";
458  return false;
459  }
460 
461  const base::DictionaryValue* license_dict = NULL;
462  RCHECK(root->GetAsDictionary(&license_dict));
463 
464  std::string license_status;
465  RCHECK(license_dict->GetString("status", &license_status));
466  if (license_status != kLicenseStatusOK) {
467  LOG(ERROR) << "Received non-OK license response: " << response;
468  *transient_error = (license_status == kLicenseStatusTransientError);
469  return false;
470  }
471 
472  const base::ListValue* tracks;
473  RCHECK(license_dict->GetList("tracks", &tracks));
474  // Should have at least one track per crypto_period.
475  RCHECK(enable_key_rotation ? tracks->GetSize() >= 1 * crypto_period_count_
476  : tracks->GetSize() >= 1);
477 
478  int current_crypto_period_index = first_crypto_period_index_;
479 
480  EncryptionKeyMap encryption_key_map;
481  for (size_t i = 0; i < tracks->GetSize(); ++i) {
482  const base::DictionaryValue* track_dict;
483  RCHECK(tracks->GetDictionary(i, &track_dict));
484 
485  if (enable_key_rotation) {
486  int crypto_period_index;
487  RCHECK(
488  track_dict->GetInteger("crypto_period_index", &crypto_period_index));
489  if (crypto_period_index != current_crypto_period_index) {
490  if (crypto_period_index != current_crypto_period_index + 1) {
491  LOG(ERROR) << "Expecting crypto period index "
492  << current_crypto_period_index << " or "
493  << current_crypto_period_index + 1 << "; Seen "
494  << crypto_period_index << " at track " << i;
495  return false;
496  }
497  if (!PushToKeyPool(&encryption_key_map))
498  return false;
499  ++current_crypto_period_index;
500  }
501  }
502 
503  std::string track_type_str;
504  RCHECK(track_dict->GetString("type", &track_type_str));
505  TrackType track_type = GetTrackTypeFromString(track_type_str);
506  DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
507  RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
508 
509  scoped_ptr<EncryptionKey> encryption_key(new EncryptionKey());
510 
511  if (!GetKeyFromTrack(*track_dict, &encryption_key->key))
512  return false;
513 
514  // Get key ID and PSSH data for CENC content only.
515  if (!widevine_classic) {
516  if (!GetKeyIdFromTrack(*track_dict, &encryption_key->key_id))
517  return false;
518 
519  std::vector<uint8_t> pssh_data;
520  if (!GetPsshDataFromTrack(*track_dict, &pssh_data))
521  return false;
522  encryption_key->pssh = PsshBoxFromPsshData(pssh_data);
523  }
524  encryption_key_map[track_type] = encryption_key.release();
525  }
526 
527  DCHECK(!encryption_key_map.empty());
528  if (!enable_key_rotation) {
529  encryption_key_map_ = encryption_key_map;
530  return true;
531  }
532  return PushToKeyPool(&encryption_key_map);
533 }
534 
535 bool WidevineKeySource::PushToKeyPool(
536  EncryptionKeyMap* encryption_key_map) {
537  DCHECK(key_pool_);
538  DCHECK(encryption_key_map);
539  Status status =
540  key_pool_->Push(scoped_refptr<RefCountedEncryptionKeyMap>(
541  new RefCountedEncryptionKeyMap(encryption_key_map)),
542  kInfiniteTimeout);
543  encryption_key_map->clear();
544  if (!status.ok()) {
545  DCHECK_EQ(error::STOPPED, status.error_code());
546  return false;
547  }
548  return true;
549 }
550 
551 } // namespace media
552 } // namespace edash_packager
WidevineKeySource(const std::string &server_url)
void set_signer(scoped_ptr< RequestSigner > signer)
Status GetKey(TrackType track_type, EncryptionKey *key) override
void set_key_fetcher(scoped_ptr< KeyFetcher > key_fetcher)
Status FetchKeys(const std::vector< uint8_t > &content_id, const std::string &policy) override
static std::vector< uint8_t > PsshBoxFromPsshData(const std::vector< uint8_t > &pssh_data)
Definition: key_source.cc:164
Status GetCryptoPeriodKey(uint32_t crypto_period_index, TrackType track_type, EncryptionKey *key) override
static TrackType GetTrackTypeFromString(const std::string &track_type_string)
Convert string representation of track type to enum representation.
Definition: key_source.cc:136
static std::string TrackTypeToString(TrackType track_type)
Convert TrackType to string.
Definition: key_source.cc:150