DASH Media Packaging SDK
 All Classes Namespaces Functions Variables Typedefs Enumerator
widevine_key_source.cc
1 // Copyright 2014 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/widevine_key_source.h"
8 
9 #include "packager/base/base64.h"
10 #include "packager/base/bind.h"
11 #include "packager/base/json/json_reader.h"
12 #include "packager/base/json/json_writer.h"
13 #include "packager/base/memory/ref_counted.h"
14 #include "packager/base/stl_util.h"
15 #include "packager/media/base/http_key_fetcher.h"
16 #include "packager/media/base/producer_consumer_queue.h"
17 #include "packager/media/base/protection_system_specific_info.h"
18 #include "packager/media/base/request_signer.h"
19 #include "packager/media/base/widevine_pssh_data.pb.h"
20 
// RCHECK(x): evaluates the boolean expression |x|. When it is false, the
// stringified expression is logged at ERROR level and the enclosing function
// returns false, so the macro may only be used in functions returning bool.
// The do { } while (0) wrapper makes the expansion a single statement that is
// safe inside unbraced if/else branches.
#define RCHECK(x)                                         \
  do {                                                    \
    if (!(x)) {                                           \
      LOG(ERROR) << "Failure while processing: " << #x;   \
      return false;                                       \
    }                                                     \
  } while (0)
28 
29 namespace edash_packager {
30 namespace {
31 
// Passed (or negated) as the |enable_key_rotation| argument of
// FetchKeysInternal() so call sites read as intent rather than a bare bool.
constexpr bool kEnableKeyRotation = true;

// Values of the "status" field of a license response.
constexpr char kLicenseStatusOK[] = "OK";
// Server may return INTERNAL_ERROR intermittently, which is a transient error
// and the next client request may succeed without problem.
constexpr char kLicenseStatusTransientError[] = "INTERNAL_ERROR";

// Total number of key request attempts (the first one included) made by
// FetchKeysInternal() while the server keeps reporting a transient error or
// the fetch times out. The delay between attempts starts at
// kFirstRetryDelayMilliseconds and doubles after every failed attempt.
constexpr int kNumTransientErrorRetries = 5;
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default crypto period count, which is the number of keys to fetch on every
// key rotation enabled request.
constexpr int kDefaultCryptoPeriodCount = 10;

// Timeouts: how long GetKeyInternal() waits for a crypto period's keys, and
// the timeout handed to the default HttpKeyFetcher.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
constexpr int kKeyFetchTimeoutInSeconds = 60;    // 1 minute.

// 16-byte Widevine DRM system ID, as found in Widevine PSSH boxes.
constexpr uint8_t kWidevineSystemId[] = {0xed, 0xef, 0x8b, 0xa9,
                                         0x79, 0xd6, 0x4a, 0xce,
                                         0xa3, 0xc8, 0x27, 0xdc,
                                         0xd5, 0x1d, 0x21, 0xed};
53 
54 bool Base64StringToBytes(const std::string& base64_string,
55  std::vector<uint8_t>* bytes) {
56  DCHECK(bytes);
57  std::string str;
58  if (!base::Base64Decode(base64_string, &str))
59  return false;
60  bytes->assign(str.begin(), str.end());
61  return true;
62 }
63 
64 void BytesToBase64String(const std::vector<uint8_t>& bytes,
65  std::string* base64_string) {
66  DCHECK(base64_string);
67  base::Base64Encode(base::StringPiece(reinterpret_cast<const char*>
68  (bytes.data()), bytes.size()),
69  base64_string);
70 }
71 
72 bool GetKeyFromTrack(const base::DictionaryValue& track_dict,
73  std::vector<uint8_t>* key) {
74  DCHECK(key);
75  std::string key_base64_string;
76  RCHECK(track_dict.GetString("key", &key_base64_string));
77  VLOG(2) << "Key:" << key_base64_string;
78  RCHECK(Base64StringToBytes(key_base64_string, key));
79  return true;
80 }
81 
82 bool GetKeyIdFromTrack(const base::DictionaryValue& track_dict,
83  std::vector<uint8_t>* key_id) {
84  DCHECK(key_id);
85  std::string key_id_base64_string;
86  RCHECK(track_dict.GetString("key_id", &key_id_base64_string));
87  VLOG(2) << "Keyid:" << key_id_base64_string;
88  RCHECK(Base64StringToBytes(key_id_base64_string, key_id));
89  return true;
90 }
91 
92 bool GetPsshDataFromTrack(const base::DictionaryValue& track_dict,
93  std::vector<uint8_t>* pssh_data) {
94  DCHECK(pssh_data);
95 
96  const base::ListValue* pssh_list;
97  RCHECK(track_dict.GetList("pssh", &pssh_list));
98  // Invariant check. We don't want to crash in release mode if possible.
99  // The following code handles it gracefully if GetSize() does not return 1.
100  DCHECK_EQ(1u, pssh_list->GetSize());
101 
102  const base::DictionaryValue* pssh_dict;
103  RCHECK(pssh_list->GetDictionary(0, &pssh_dict));
104  std::string drm_type;
105  RCHECK(pssh_dict->GetString("drm_type", &drm_type));
106  if (drm_type != "WIDEVINE") {
107  LOG(ERROR) << "Expecting drm_type 'WIDEVINE', get '" << drm_type << "'.";
108  return false;
109  }
110  std::string pssh_data_base64_string;
111  RCHECK(pssh_dict->GetString("data", &pssh_data_base64_string));
112 
113  VLOG(2) << "Pssh Data:" << pssh_data_base64_string;
114  RCHECK(Base64StringToBytes(pssh_data_base64_string, pssh_data));
115  return true;
116 }
117 
118 } // namespace
119 
120 namespace media {
121 
122 // A ref counted wrapper for EncryptionKeyMap.
123 class WidevineKeySource::RefCountedEncryptionKeyMap
124  : public base::RefCountedThreadSafe<RefCountedEncryptionKeyMap> {
125  public:
126  explicit RefCountedEncryptionKeyMap(EncryptionKeyMap* encryption_key_map) {
127  DCHECK(encryption_key_map);
128  encryption_key_map_.swap(*encryption_key_map);
129  }
130 
131  std::map<KeySource::TrackType, EncryptionKey*>& map() {
132  return encryption_key_map_;
133  }
134 
135  private:
136  friend class base::RefCountedThreadSafe<RefCountedEncryptionKeyMap>;
137 
138  ~RefCountedEncryptionKeyMap() { STLDeleteValues(&encryption_key_map_); }
139 
140  EncryptionKeyMap encryption_key_map_;
141 
142  DISALLOW_COPY_AND_ASSIGN(RefCountedEncryptionKeyMap);
143 };
144 
145 WidevineKeySource::WidevineKeySource(const std::string& server_url)
146  : key_production_thread_("KeyProductionThread",
147  base::Bind(&WidevineKeySource::FetchKeysTask,
148  base::Unretained(this))),
149  key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
150  server_url_(server_url),
151  crypto_period_count_(kDefaultCryptoPeriodCount),
152  key_production_started_(false),
153  start_key_production_(false, false),
154  first_crypto_period_index_(0) {
155  key_production_thread_.Start();
156 }
157 
158 WidevineKeySource::~WidevineKeySource() {
159  if (key_pool_)
160  key_pool_->Stop();
161  if (key_production_thread_.HasBeenStarted()) {
162  // Signal the production thread to start key production if it is not
163  // signaled yet so the thread can be joined.
164  start_key_production_.Signal();
165  key_production_thread_.Join();
166  }
167  STLDeleteValues(&encryption_key_map_);
168 }
169 
170 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
171  const std::string& policy) {
172  base::AutoLock scoped_lock(lock_);
173  request_dict_.Clear();
174  std::string content_id_base64_string;
175  BytesToBase64String(content_id, &content_id_base64_string);
176  request_dict_.SetString("content_id", content_id_base64_string);
177  request_dict_.SetString("policy", policy);
178  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
179 }
180 
181 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& pssh_box) {
182  const std::vector<uint8_t> widevine_system_id(
183  kWidevineSystemId, kWidevineSystemId + arraysize(kWidevineSystemId));
184 
186  if (!info.Parse(pssh_box.data(), pssh_box.size()))
187  return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
188 
189  if (info.system_id() == widevine_system_id) {
190  base::AutoLock scoped_lock(lock_);
191  request_dict_.Clear();
192  std::string pssh_data_base64_string;
193 
194  BytesToBase64String(info.pssh_data(), &pssh_data_base64_string);
195  request_dict_.SetString("pssh_data", pssh_data_base64_string);
196  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
197  } else if (!info.key_ids().empty()) {
198  // This is not a Widevine PSSH box. Try making the request for the key-IDs.
199  // Even if this is a different key-system, it should still work. Either
200  // the server will not recognize it and return an error, or it will
201  // recognize it and the key must be correct (or the content is bad).
202  return FetchKeys(info.key_ids());
203  } else {
204  return Status(error::NOT_FOUND, "No key IDs given in PSSH box.");
205  }
206 }
207 
209  const std::vector<std::vector<uint8_t>>& key_ids) {
210  base::AutoLock scoped_lock(lock_);
211  request_dict_.Clear();
212  std::string pssh_data_base64_string;
213 
214  // Generate Widevine PSSH data from the key-IDs.
215  WidevinePsshData widevine_pssh_data;
216  for (size_t i = 0; i < key_ids.size(); i++) {
217  widevine_pssh_data.add_key_id(key_ids[i].data(), key_ids[i].size());
218  }
219 
220  const std::string serialized_string = widevine_pssh_data.SerializeAsString();
221  std::vector<uint8_t> pssh_data(serialized_string.begin(),
222  serialized_string.end());
223 
224  BytesToBase64String(pssh_data, &pssh_data_base64_string);
225  request_dict_.SetString("pssh_data", pssh_data_base64_string);
226  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
227 }
228 
230  base::AutoLock scoped_lock(lock_);
231  request_dict_.Clear();
232  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
233  // instead as 32-bit integer can be lossless represented using double.
234  request_dict_.SetDouble("asset_id", asset_id);
235  return FetchKeysInternal(!kEnableKeyRotation, 0, true);
236 }
237 
238 Status WidevineKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
239  DCHECK(key);
240  if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
241  return Status(error::INTERNAL_ERROR,
242  "Cannot find key of type " + TrackTypeToString(track_type));
243  }
244  *key = *encryption_key_map_[track_type];
245  return Status::OK;
246 }
247 
248 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
249  EncryptionKey* key) {
250  DCHECK(key);
251  for (std::map<TrackType, EncryptionKey*>::iterator iter =
252  encryption_key_map_.begin();
253  iter != encryption_key_map_.end();
254  ++iter) {
255  if (iter->second->key_id == key_id) {
256  *key = *iter->second;
257  return Status::OK;
258  }
259  }
260  return Status(error::INTERNAL_ERROR,
261  "Cannot find key with specified key ID");
262 }
263 
264 Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
265  TrackType track_type,
266  EncryptionKey* key) {
267  DCHECK(key_production_thread_.HasBeenStarted());
268  // TODO(kqyang): This is not elegant. Consider refactoring later.
269  {
270  base::AutoLock scoped_lock(lock_);
271  if (!key_production_started_) {
272  // Another client may have a slightly smaller starting crypto period
273  // index. Set the initial value to account for that.
274  first_crypto_period_index_ =
275  crypto_period_index ? crypto_period_index - 1 : 0;
276  DCHECK(!key_pool_);
277  key_pool_.reset(new EncryptionKeyQueue(crypto_period_count_,
278  first_crypto_period_index_));
279  start_key_production_.Signal();
280  key_production_started_ = true;
281  }
282  }
283  return GetKeyInternal(crypto_period_index, track_type, key);
284 }
285 
286 void WidevineKeySource::set_signer(scoped_ptr<RequestSigner> signer) {
287  signer_ = signer.Pass();
288 }
289 
290 void WidevineKeySource::set_key_fetcher(scoped_ptr<KeyFetcher> key_fetcher) {
291  key_fetcher_ = key_fetcher.Pass();
292 }
293 
294 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
295  TrackType track_type,
296  EncryptionKey* key) {
297  DCHECK(key_pool_);
298  DCHECK(key);
299  DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
300  DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
301 
302  scoped_refptr<RefCountedEncryptionKeyMap> ref_counted_encryption_key_map;
303  Status status =
304  key_pool_->Peek(crypto_period_index, &ref_counted_encryption_key_map,
305  kGetKeyTimeoutInSeconds * 1000);
306  if (!status.ok()) {
307  if (status.error_code() == error::STOPPED) {
308  CHECK(!common_encryption_request_status_.ok());
309  return common_encryption_request_status_;
310  }
311  return status;
312  }
313 
314  EncryptionKeyMap& encryption_key_map = ref_counted_encryption_key_map->map();
315  if (encryption_key_map.find(track_type) == encryption_key_map.end()) {
316  return Status(error::INTERNAL_ERROR,
317  "Cannot find key of type " + TrackTypeToString(track_type));
318  }
319  *key = *encryption_key_map[track_type];
320  return Status::OK;
321 }
322 
323 void WidevineKeySource::FetchKeysTask() {
324  // Wait until key production is signaled.
325  start_key_production_.Wait();
326  if (!key_pool_ || key_pool_->Stopped())
327  return;
328 
329  Status status = FetchKeysInternal(kEnableKeyRotation,
330  first_crypto_period_index_,
331  false);
332  while (status.ok()) {
333  first_crypto_period_index_ += crypto_period_count_;
334  status = FetchKeysInternal(kEnableKeyRotation,
335  first_crypto_period_index_,
336  false);
337  }
338  common_encryption_request_status_ = status;
339  key_pool_->Stop();
340 }
341 
342 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
343  uint32_t first_crypto_period_index,
344  bool widevine_classic) {
345  std::string request;
346  FillRequest(enable_key_rotation,
347  first_crypto_period_index,
348  &request);
349 
350  std::string message;
351  Status status = GenerateKeyMessage(request, &message);
352  if (!status.ok())
353  return status;
354  VLOG(1) << "Message: " << message;
355 
356  std::string raw_response;
357  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
358 
359  // Perform client side retries if seeing server transient error to workaround
360  // server limitation.
361  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
362  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
363  if (status.ok()) {
364  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
365 
366  std::string response;
367  if (!DecodeResponse(raw_response, &response)) {
368  return Status(error::SERVER_ERROR,
369  "Failed to decode response '" + raw_response + "'.");
370  }
371 
372  bool transient_error = false;
373  if (ExtractEncryptionKey(enable_key_rotation,
374  widevine_classic,
375  response,
376  &transient_error))
377  return Status::OK;
378 
379  if (!transient_error) {
380  return Status(
381  error::SERVER_ERROR,
382  "Failed to extract encryption key from '" + response + "'.");
383  }
384  } else if (status.error_code() != error::TIME_OUT) {
385  return status;
386  }
387 
388  // Exponential backoff.
389  if (i != kNumTransientErrorRetries - 1) {
390  base::PlatformThread::Sleep(
391  base::TimeDelta::FromMilliseconds(sleep_duration));
392  sleep_duration *= 2;
393  }
394  }
395  return Status(error::SERVER_ERROR,
396  "Failed to recover from server internal error.");
397 }
398 
399 void WidevineKeySource::FillRequest(bool enable_key_rotation,
400  uint32_t first_crypto_period_index,
401  std::string* request) {
402  DCHECK(request);
403  DCHECK(!request_dict_.empty());
404 
405  // Build tracks.
406  base::ListValue* tracks = new base::ListValue();
407 
408  base::DictionaryValue* track_sd = new base::DictionaryValue();
409  track_sd->SetString("type", "SD");
410  tracks->Append(track_sd);
411  base::DictionaryValue* track_hd = new base::DictionaryValue();
412  track_hd->SetString("type", "HD");
413  tracks->Append(track_hd);
414  base::DictionaryValue* track_audio = new base::DictionaryValue();
415  track_audio->SetString("type", "AUDIO");
416  tracks->Append(track_audio);
417 
418  request_dict_.Set("tracks", tracks);
419 
420  // Build DRM types.
421  base::ListValue* drm_types = new base::ListValue();
422  drm_types->AppendString("WIDEVINE");
423  request_dict_.Set("drm_types", drm_types);
424 
425  // Build key rotation fields.
426  if (enable_key_rotation) {
427  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
428  // instead as 32-bit integer can be lossless represented using double.
429  request_dict_.SetDouble("first_crypto_period_index",
430  first_crypto_period_index);
431  request_dict_.SetInteger("crypto_period_count", crypto_period_count_);
432  }
433 
434  base::JSONWriter::WriteWithOptions(
435  request_dict_,
436  // Write doubles that have no fractional part as a normal integer, i.e.
437  // without using exponential notation or appending a '.0'.
438  base::JSONWriter::OPTIONS_OMIT_DOUBLE_TYPE_PRESERVATION, request);
439 }
440 
441 Status WidevineKeySource::GenerateKeyMessage(const std::string& request,
442  std::string* message) {
443  DCHECK(message);
444 
445  std::string request_base64_string;
446  base::Base64Encode(request, &request_base64_string);
447 
448  base::DictionaryValue request_dict;
449  request_dict.SetString("request", request_base64_string);
450 
451  // Sign the request.
452  if (signer_) {
453  std::string signature;
454  if (!signer_->GenerateSignature(request, &signature))
455  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
456 
457  std::string signature_base64_string;
458  base::Base64Encode(signature, &signature_base64_string);
459 
460  request_dict.SetString("signature", signature_base64_string);
461  request_dict.SetString("signer", signer_->signer_name());
462  }
463 
464  base::JSONWriter::Write(request_dict, message);
465  return Status::OK;
466 }
467 
468 bool WidevineKeySource::DecodeResponse(
469  const std::string& raw_response,
470  std::string* response) {
471  DCHECK(response);
472 
473  // Extract base64 formatted response from JSON formatted raw response.
474  scoped_ptr<base::Value> root(base::JSONReader::Read(raw_response));
475  if (!root) {
476  LOG(ERROR) << "'" << raw_response << "' is not in JSON format.";
477  return false;
478  }
479  const base::DictionaryValue* response_dict = NULL;
480  RCHECK(root->GetAsDictionary(&response_dict));
481 
482  std::string response_base64_string;
483  RCHECK(response_dict->GetString("response", &response_base64_string));
484  RCHECK(base::Base64Decode(response_base64_string, response));
485  return true;
486 }
487 
488 bool WidevineKeySource::ExtractEncryptionKey(
489  bool enable_key_rotation,
490  bool widevine_classic,
491  const std::string& response,
492  bool* transient_error) {
493  DCHECK(transient_error);
494  *transient_error = false;
495 
496  scoped_ptr<base::Value> root(base::JSONReader::Read(response));
497  if (!root) {
498  LOG(ERROR) << "'" << response << "' is not in JSON format.";
499  return false;
500  }
501 
502  const base::DictionaryValue* license_dict = NULL;
503  RCHECK(root->GetAsDictionary(&license_dict));
504 
505  std::string license_status;
506  RCHECK(license_dict->GetString("status", &license_status));
507  if (license_status != kLicenseStatusOK) {
508  LOG(ERROR) << "Received non-OK license response: " << response;
509  *transient_error = (license_status == kLicenseStatusTransientError);
510  return false;
511  }
512 
513  const base::ListValue* tracks;
514  RCHECK(license_dict->GetList("tracks", &tracks));
515  // Should have at least one track per crypto_period.
516  RCHECK(enable_key_rotation ? tracks->GetSize() >= 1 * crypto_period_count_
517  : tracks->GetSize() >= 1);
518 
519  int current_crypto_period_index = first_crypto_period_index_;
520 
521  EncryptionKeyMap encryption_key_map;
522  for (size_t i = 0; i < tracks->GetSize(); ++i) {
523  const base::DictionaryValue* track_dict;
524  RCHECK(tracks->GetDictionary(i, &track_dict));
525 
526  if (enable_key_rotation) {
527  int crypto_period_index;
528  RCHECK(
529  track_dict->GetInteger("crypto_period_index", &crypto_period_index));
530  if (crypto_period_index != current_crypto_period_index) {
531  if (crypto_period_index != current_crypto_period_index + 1) {
532  LOG(ERROR) << "Expecting crypto period index "
533  << current_crypto_period_index << " or "
534  << current_crypto_period_index + 1 << "; Seen "
535  << crypto_period_index << " at track " << i;
536  return false;
537  }
538  if (!PushToKeyPool(&encryption_key_map))
539  return false;
540  ++current_crypto_period_index;
541  }
542  }
543 
544  std::string track_type_str;
545  RCHECK(track_dict->GetString("type", &track_type_str));
546  TrackType track_type = GetTrackTypeFromString(track_type_str);
547  DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
548  RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
549 
550  scoped_ptr<EncryptionKey> encryption_key(new EncryptionKey());
551 
552  if (!GetKeyFromTrack(*track_dict, &encryption_key->key))
553  return false;
554 
555  // Get key ID and PSSH data for CENC content only.
556  if (!widevine_classic) {
557  if (!GetKeyIdFromTrack(*track_dict, &encryption_key->key_id))
558  return false;
559 
560  ProtectionSystemSpecificInfo info;
561  info.add_key_id(encryption_key->key_id);
562  info.set_system_id(kWidevineSystemId, arraysize(kWidevineSystemId));
563  info.set_pssh_box_version(0);
564 
565  std::vector<uint8_t> pssh_data;
566  if (!GetPsshDataFromTrack(*track_dict, &pssh_data))
567  return false;
568  info.set_pssh_data(pssh_data);
569 
570  encryption_key->key_system_info.push_back(info);
571  }
572  encryption_key_map[track_type] = encryption_key.release();
573  }
574 
575  // NOTE: To support version 1 pssh, update ProtectionSystemSpecificInfo to
576  // include all key IDs in |encryption_key_map|.
577  DCHECK(!encryption_key_map.empty());
578  if (!enable_key_rotation) {
579  encryption_key_map_ = encryption_key_map;
580  return true;
581  }
582  return PushToKeyPool(&encryption_key_map);
583 }
584 
585 bool WidevineKeySource::PushToKeyPool(
586  EncryptionKeyMap* encryption_key_map) {
587  DCHECK(key_pool_);
588  DCHECK(encryption_key_map);
589  Status status =
590  key_pool_->Push(scoped_refptr<RefCountedEncryptionKeyMap>(
591  new RefCountedEncryptionKeyMap(encryption_key_map)),
592  kInfiniteTimeout);
593  encryption_key_map->clear();
594  if (!status.ok()) {
595  DCHECK_EQ(error::STOPPED, status.error_code());
596  return false;
597  }
598  return true;
599 }
600 
601 } // namespace media
602 } // namespace edash_packager
WidevineKeySource(const std::string &server_url)
void set_signer(scoped_ptr< RequestSigner > signer)
Status GetKey(TrackType track_type, EncryptionKey *key) override
Status FetchKeys(const std::vector< uint8_t > &pssh_box) override
void set_key_fetcher(scoped_ptr< KeyFetcher > key_fetcher)
Status GetCryptoPeriodKey(uint32_t crypto_period_index, TrackType track_type, EncryptionKey *key) override
static TrackType GetTrackTypeFromString(const std::string &track_type_string)
Convert string representation of track type to enum representation.
Definition: key_source.cc:19
static std::string TrackTypeToString(TrackType track_type)
Convert TrackType to string.
Definition: key_source.cc:33