DASH Media Packaging SDK
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator
widevine_key_source.cc
1 // Copyright 2014 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/widevine_key_source.h"
8 
9 #include <set>
10 
11 #include "packager/base/base64.h"
12 #include "packager/base/bind.h"
13 #include "packager/base/json/json_reader.h"
14 #include "packager/base/json/json_writer.h"
15 #include "packager/base/memory/ref_counted.h"
16 #include "packager/media/base/fixed_key_source.h"
17 #include "packager/media/base/http_key_fetcher.h"
18 #include "packager/media/base/producer_consumer_queue.h"
19 #include "packager/media/base/protection_system_specific_info.h"
20 #include "packager/media/base/rcheck.h"
21 #include "packager/media/base/request_signer.h"
22 #include "packager/media/base/widevine_pssh_data.pb.h"
23 
24 namespace shaka {
25 namespace {
26 
constexpr bool kEnableKeyRotation = true;

// Values of the "status" field in a license server response.
constexpr char kLicenseStatusOK[] = "OK";
// The server may intermittently answer INTERNAL_ERROR. That is a transient
// failure: a subsequent request from the client may well succeed.
constexpr char kLicenseStatusTransientError[] = "INTERNAL_ERROR";

// Number of attempts made for one key request when the server reports a
// transient error or the fetch times out, and the delay before the first
// retry (the delay doubles after each failed attempt).
constexpr int kNumTransientErrorRetries = 5;
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default crypto period count, i.e. the number of keys fetched by every
// request when key rotation is enabled.
constexpr int kDefaultCryptoPeriodCount = 10;
// How long a reader waits on the key pool for a crypto period's keys.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Timeout given to the HTTP key fetcher.
constexpr int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
44 
45 bool Base64StringToBytes(const std::string& base64_string,
46  std::vector<uint8_t>* bytes) {
47  DCHECK(bytes);
48  std::string str;
49  if (!base::Base64Decode(base64_string, &str))
50  return false;
51  bytes->assign(str.begin(), str.end());
52  return true;
53 }
54 
55 void BytesToBase64String(const std::vector<uint8_t>& bytes,
56  std::string* base64_string) {
57  DCHECK(base64_string);
58  base::Base64Encode(base::StringPiece(reinterpret_cast<const char*>
59  (bytes.data()), bytes.size()),
60  base64_string);
61 }
62 
63 bool GetKeyFromTrack(const base::DictionaryValue& track_dict,
64  std::vector<uint8_t>* key) {
65  DCHECK(key);
66  std::string key_base64_string;
67  RCHECK(track_dict.GetString("key", &key_base64_string));
68  VLOG(2) << "Key:" << key_base64_string;
69  RCHECK(Base64StringToBytes(key_base64_string, key));
70  return true;
71 }
72 
73 bool GetKeyIdFromTrack(const base::DictionaryValue& track_dict,
74  std::vector<uint8_t>* key_id) {
75  DCHECK(key_id);
76  std::string key_id_base64_string;
77  RCHECK(track_dict.GetString("key_id", &key_id_base64_string));
78  VLOG(2) << "Keyid:" << key_id_base64_string;
79  RCHECK(Base64StringToBytes(key_id_base64_string, key_id));
80  return true;
81 }
82 
83 bool GetPsshDataFromTrack(const base::DictionaryValue& track_dict,
84  std::vector<uint8_t>* pssh_data) {
85  DCHECK(pssh_data);
86 
87  const base::ListValue* pssh_list;
88  RCHECK(track_dict.GetList("pssh", &pssh_list));
89  // Invariant check. We don't want to crash in release mode if possible.
90  // The following code handles it gracefully if GetSize() does not return 1.
91  DCHECK_EQ(1u, pssh_list->GetSize());
92 
93  const base::DictionaryValue* pssh_dict;
94  RCHECK(pssh_list->GetDictionary(0, &pssh_dict));
95  std::string drm_type;
96  RCHECK(pssh_dict->GetString("drm_type", &drm_type));
97  if (drm_type != "WIDEVINE") {
98  LOG(ERROR) << "Expecting drm_type 'WIDEVINE', get '" << drm_type << "'.";
99  return false;
100  }
101  std::string pssh_data_base64_string;
102  RCHECK(pssh_dict->GetString("data", &pssh_data_base64_string));
103 
104  VLOG(2) << "Pssh Data:" << pssh_data_base64_string;
105  RCHECK(Base64StringToBytes(pssh_data_base64_string, pssh_data));
106  return true;
107 }
108 
109 } // namespace
110 
111 namespace media {
112 
113 // A ref counted wrapper for EncryptionKeyMap.
114 class WidevineKeySource::RefCountedEncryptionKeyMap
115  : public base::RefCountedThreadSafe<RefCountedEncryptionKeyMap> {
116  public:
117  explicit RefCountedEncryptionKeyMap(EncryptionKeyMap* encryption_key_map) {
118  DCHECK(encryption_key_map);
119  encryption_key_map_.swap(*encryption_key_map);
120  }
121 
122  const EncryptionKeyMap& map() { return encryption_key_map_; }
123 
124  private:
125  friend class base::RefCountedThreadSafe<RefCountedEncryptionKeyMap>;
126 
127  ~RefCountedEncryptionKeyMap() {}
128 
129  EncryptionKeyMap encryption_key_map_;
130 
131  DISALLOW_COPY_AND_ASSIGN(RefCountedEncryptionKeyMap);
132 };
133 
134 WidevineKeySource::WidevineKeySource(const std::string& server_url,
135  bool add_common_pssh)
136  : key_production_thread_("KeyProductionThread",
137  base::Bind(&WidevineKeySource::FetchKeysTask,
138  base::Unretained(this))),
139  key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
140  server_url_(server_url),
141  crypto_period_count_(kDefaultCryptoPeriodCount),
142  add_common_pssh_(add_common_pssh),
143  key_production_started_(false),
144  start_key_production_(base::WaitableEvent::ResetPolicy::AUTOMATIC,
145  base::WaitableEvent::InitialState::NOT_SIGNALED),
146  first_crypto_period_index_(0) {
147  key_production_thread_.Start();
148 }
149 
150 WidevineKeySource::~WidevineKeySource() {
151  if (key_pool_)
152  key_pool_->Stop();
153  if (key_production_thread_.HasBeenStarted()) {
154  // Signal the production thread to start key production if it is not
155  // signaled yet so the thread can be joined.
156  start_key_production_.Signal();
157  key_production_thread_.Join();
158  }
159 }
160 
161 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
162  const std::string& policy) {
163  base::AutoLock scoped_lock(lock_);
164  request_dict_.Clear();
165  std::string content_id_base64_string;
166  BytesToBase64String(content_id, &content_id_base64_string);
167  request_dict_.SetString("content_id", content_id_base64_string);
168  request_dict_.SetString("policy", policy);
169  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
170 }
171 
172 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& pssh_box) {
173  const std::vector<uint8_t> widevine_system_id(
174  kWidevineSystemId, kWidevineSystemId + arraysize(kWidevineSystemId));
175 
177  if (!info.Parse(pssh_box.data(), pssh_box.size()))
178  return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
179 
180  if (info.system_id() == widevine_system_id) {
181  base::AutoLock scoped_lock(lock_);
182  request_dict_.Clear();
183  std::string pssh_data_base64_string;
184 
185  BytesToBase64String(info.pssh_data(), &pssh_data_base64_string);
186  request_dict_.SetString("pssh_data", pssh_data_base64_string);
187  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
188  } else if (!info.key_ids().empty()) {
189  // This is not a Widevine PSSH box. Try making the request for the key-IDs.
190  // Even if this is a different key-system, it should still work. Either
191  // the server will not recognize it and return an error, or it will
192  // recognize it and the key must be correct (or the content is bad).
193  return FetchKeys(info.key_ids());
194  } else {
195  return Status(error::NOT_FOUND, "No key IDs given in PSSH box.");
196  }
197 }
198 
200  const std::vector<std::vector<uint8_t>>& key_ids) {
201  base::AutoLock scoped_lock(lock_);
202  request_dict_.Clear();
203  std::string pssh_data_base64_string;
204 
205  // Generate Widevine PSSH data from the key-IDs.
206  WidevinePsshData widevine_pssh_data;
207  for (size_t i = 0; i < key_ids.size(); i++) {
208  widevine_pssh_data.add_key_id(key_ids[i].data(), key_ids[i].size());
209  }
210 
211  const std::string serialized_string = widevine_pssh_data.SerializeAsString();
212  std::vector<uint8_t> pssh_data(serialized_string.begin(),
213  serialized_string.end());
214 
215  BytesToBase64String(pssh_data, &pssh_data_base64_string);
216  request_dict_.SetString("pssh_data", pssh_data_base64_string);
217  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
218 }
219 
221  base::AutoLock scoped_lock(lock_);
222  request_dict_.Clear();
223  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
224  // instead as 32-bit integer can be lossless represented using double.
225  request_dict_.SetDouble("asset_id", asset_id);
226  return FetchKeysInternal(!kEnableKeyRotation, 0, true);
227 }
228 
229 Status WidevineKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
230  DCHECK(key);
231  if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
232  return Status(error::INTERNAL_ERROR,
233  "Cannot find key of type " + TrackTypeToString(track_type));
234  }
235  *key = *encryption_key_map_[track_type];
236  return Status::OK;
237 }
238 
239 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
240  EncryptionKey* key) {
241  DCHECK(key);
242  for (const auto& pair : encryption_key_map_) {
243  if (pair.second->key_id == key_id) {
244  *key = *pair.second;
245  return Status::OK;
246  }
247  }
248  return Status(error::INTERNAL_ERROR,
249  "Cannot find key with specified key ID");
250 }
251 
252 Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
253  TrackType track_type,
254  EncryptionKey* key) {
255  DCHECK(key_production_thread_.HasBeenStarted());
256  // TODO(kqyang): This is not elegant. Consider refactoring later.
257  {
258  base::AutoLock scoped_lock(lock_);
259  if (!key_production_started_) {
260  // Another client may have a slightly smaller starting crypto period
261  // index. Set the initial value to account for that.
262  first_crypto_period_index_ =
263  crypto_period_index ? crypto_period_index - 1 : 0;
264  DCHECK(!key_pool_);
265  key_pool_.reset(new EncryptionKeyQueue(crypto_period_count_,
266  first_crypto_period_index_));
267  start_key_production_.Signal();
268  key_production_started_ = true;
269  }
270  }
271  return GetKeyInternal(crypto_period_index, track_type, key);
272 }
273 
274 void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
275  signer_ = std::move(signer);
276 }
277 
279  std::unique_ptr<KeyFetcher> key_fetcher) {
280  key_fetcher_ = std::move(key_fetcher);
281 }
282 
283 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
284  TrackType track_type,
285  EncryptionKey* key) {
286  DCHECK(key_pool_);
287  DCHECK(key);
288  DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
289  DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
290 
291  scoped_refptr<RefCountedEncryptionKeyMap> ref_counted_encryption_key_map;
292  Status status =
293  key_pool_->Peek(crypto_period_index, &ref_counted_encryption_key_map,
294  kGetKeyTimeoutInSeconds * 1000);
295  if (!status.ok()) {
296  if (status.error_code() == error::STOPPED) {
297  CHECK(!common_encryption_request_status_.ok());
298  return common_encryption_request_status_;
299  }
300  return status;
301  }
302 
303  const EncryptionKeyMap& encryption_key_map =
304  ref_counted_encryption_key_map->map();
305  if (encryption_key_map.find(track_type) == encryption_key_map.end()) {
306  return Status(error::INTERNAL_ERROR,
307  "Cannot find key of type " + TrackTypeToString(track_type));
308  }
309  *key = *encryption_key_map.at(track_type);
310  return Status::OK;
311 }
312 
313 void WidevineKeySource::FetchKeysTask() {
314  // Wait until key production is signaled.
315  start_key_production_.Wait();
316  if (!key_pool_ || key_pool_->Stopped())
317  return;
318 
319  Status status = FetchKeysInternal(kEnableKeyRotation,
320  first_crypto_period_index_,
321  false);
322  while (status.ok()) {
323  first_crypto_period_index_ += crypto_period_count_;
324  status = FetchKeysInternal(kEnableKeyRotation,
325  first_crypto_period_index_,
326  false);
327  }
328  common_encryption_request_status_ = status;
329  key_pool_->Stop();
330 }
331 
332 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
333  uint32_t first_crypto_period_index,
334  bool widevine_classic) {
335  std::string request;
336  FillRequest(enable_key_rotation,
337  first_crypto_period_index,
338  &request);
339 
340  std::string message;
341  Status status = GenerateKeyMessage(request, &message);
342  if (!status.ok())
343  return status;
344  VLOG(1) << "Message: " << message;
345 
346  std::string raw_response;
347  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
348 
349  // Perform client side retries if seeing server transient error to workaround
350  // server limitation.
351  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
352  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
353  if (status.ok()) {
354  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
355 
356  std::string response;
357  if (!DecodeResponse(raw_response, &response)) {
358  return Status(error::SERVER_ERROR,
359  "Failed to decode response '" + raw_response + "'.");
360  }
361 
362  bool transient_error = false;
363  if (ExtractEncryptionKey(enable_key_rotation,
364  widevine_classic,
365  response,
366  &transient_error))
367  return Status::OK;
368 
369  if (!transient_error) {
370  return Status(
371  error::SERVER_ERROR,
372  "Failed to extract encryption key from '" + response + "'.");
373  }
374  } else if (status.error_code() != error::TIME_OUT) {
375  return status;
376  }
377 
378  // Exponential backoff.
379  if (i != kNumTransientErrorRetries - 1) {
380  base::PlatformThread::Sleep(
381  base::TimeDelta::FromMilliseconds(sleep_duration));
382  sleep_duration *= 2;
383  }
384  }
385  return Status(error::SERVER_ERROR,
386  "Failed to recover from server internal error.");
387 }
388 
389 void WidevineKeySource::FillRequest(bool enable_key_rotation,
390  uint32_t first_crypto_period_index,
391  std::string* request) {
392  DCHECK(request);
393  DCHECK(!request_dict_.empty());
394 
395  // Build tracks.
396  base::ListValue* tracks = new base::ListValue();
397 
398  base::DictionaryValue* track_sd = new base::DictionaryValue();
399  track_sd->SetString("type", "SD");
400  tracks->Append(track_sd);
401  base::DictionaryValue* track_hd = new base::DictionaryValue();
402  track_hd->SetString("type", "HD");
403  tracks->Append(track_hd);
404  base::DictionaryValue* track_audio = new base::DictionaryValue();
405  track_audio->SetString("type", "AUDIO");
406  tracks->Append(track_audio);
407 
408  request_dict_.Set("tracks", tracks);
409 
410  // Build DRM types.
411  base::ListValue* drm_types = new base::ListValue();
412  drm_types->AppendString("WIDEVINE");
413  request_dict_.Set("drm_types", drm_types);
414 
415  // Build key rotation fields.
416  if (enable_key_rotation) {
417  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
418  // instead as 32-bit integer can be lossless represented using double.
419  request_dict_.SetDouble("first_crypto_period_index",
420  first_crypto_period_index);
421  request_dict_.SetInteger("crypto_period_count", crypto_period_count_);
422  }
423 
424  base::JSONWriter::WriteWithOptions(
425  request_dict_,
426  // Write doubles that have no fractional part as a normal integer, i.e.
427  // without using exponential notation or appending a '.0'.
428  base::JSONWriter::OPTIONS_OMIT_DOUBLE_TYPE_PRESERVATION, request);
429 }
430 
431 Status WidevineKeySource::GenerateKeyMessage(const std::string& request,
432  std::string* message) {
433  DCHECK(message);
434 
435  std::string request_base64_string;
436  base::Base64Encode(request, &request_base64_string);
437 
438  base::DictionaryValue request_dict;
439  request_dict.SetString("request", request_base64_string);
440 
441  // Sign the request.
442  if (signer_) {
443  std::string signature;
444  if (!signer_->GenerateSignature(request, &signature))
445  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
446 
447  std::string signature_base64_string;
448  base::Base64Encode(signature, &signature_base64_string);
449 
450  request_dict.SetString("signature", signature_base64_string);
451  request_dict.SetString("signer", signer_->signer_name());
452  }
453 
454  base::JSONWriter::Write(request_dict, message);
455  return Status::OK;
456 }
457 
458 bool WidevineKeySource::DecodeResponse(
459  const std::string& raw_response,
460  std::string* response) {
461  DCHECK(response);
462 
463  // Extract base64 formatted response from JSON formatted raw response.
464  std::unique_ptr<base::Value> root(base::JSONReader::Read(raw_response));
465  if (!root) {
466  LOG(ERROR) << "'" << raw_response << "' is not in JSON format.";
467  return false;
468  }
469  const base::DictionaryValue* response_dict = NULL;
470  RCHECK(root->GetAsDictionary(&response_dict));
471 
472  std::string response_base64_string;
473  RCHECK(response_dict->GetString("response", &response_base64_string));
474  RCHECK(base::Base64Decode(response_base64_string, response));
475  return true;
476 }
477 
478 bool WidevineKeySource::ExtractEncryptionKey(
479  bool enable_key_rotation,
480  bool widevine_classic,
481  const std::string& response,
482  bool* transient_error) {
483  DCHECK(transient_error);
484  *transient_error = false;
485 
486  std::unique_ptr<base::Value> root(base::JSONReader::Read(response));
487  if (!root) {
488  LOG(ERROR) << "'" << response << "' is not in JSON format.";
489  return false;
490  }
491 
492  const base::DictionaryValue* license_dict = NULL;
493  RCHECK(root->GetAsDictionary(&license_dict));
494 
495  std::string license_status;
496  RCHECK(license_dict->GetString("status", &license_status));
497  if (license_status != kLicenseStatusOK) {
498  LOG(ERROR) << "Received non-OK license response: " << response;
499  *transient_error = (license_status == kLicenseStatusTransientError);
500  return false;
501  }
502 
503  const base::ListValue* tracks;
504  RCHECK(license_dict->GetList("tracks", &tracks));
505  // Should have at least one track per crypto_period.
506  RCHECK(enable_key_rotation ? tracks->GetSize() >= 1 * crypto_period_count_
507  : tracks->GetSize() >= 1);
508 
509  int current_crypto_period_index = first_crypto_period_index_;
510 
511  EncryptionKeyMap encryption_key_map;
512  for (size_t i = 0; i < tracks->GetSize(); ++i) {
513  const base::DictionaryValue* track_dict;
514  RCHECK(tracks->GetDictionary(i, &track_dict));
515 
516  if (enable_key_rotation) {
517  int crypto_period_index;
518  RCHECK(
519  track_dict->GetInteger("crypto_period_index", &crypto_period_index));
520  if (crypto_period_index != current_crypto_period_index) {
521  if (crypto_period_index != current_crypto_period_index + 1) {
522  LOG(ERROR) << "Expecting crypto period index "
523  << current_crypto_period_index << " or "
524  << current_crypto_period_index + 1 << "; Seen "
525  << crypto_period_index << " at track " << i;
526  return false;
527  }
528  if (!PushToKeyPool(&encryption_key_map))
529  return false;
530  ++current_crypto_period_index;
531  }
532  }
533 
534  std::string track_type_str;
535  RCHECK(track_dict->GetString("type", &track_type_str));
536  TrackType track_type = GetTrackTypeFromString(track_type_str);
537  DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
538  RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
539 
540  std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
541 
542  if (!GetKeyFromTrack(*track_dict, &encryption_key->key))
543  return false;
544 
545  // Get key ID and PSSH data for CENC content only.
546  if (!widevine_classic) {
547  if (!GetKeyIdFromTrack(*track_dict, &encryption_key->key_id))
548  return false;
549 
550  ProtectionSystemSpecificInfo info;
551  info.add_key_id(encryption_key->key_id);
552  info.set_system_id(kWidevineSystemId, arraysize(kWidevineSystemId));
553  info.set_pssh_box_version(0);
554 
555  std::vector<uint8_t> pssh_data;
556  if (!GetPsshDataFromTrack(*track_dict, &pssh_data))
557  return false;
558  info.set_pssh_data(pssh_data);
559 
560  encryption_key->key_system_info.push_back(info);
561  }
562  encryption_key_map[track_type] = std::move(encryption_key);
563  }
564 
565  // If the flag exists, create a common system ID PSSH box that contains the
566  // key IDs of all the keys.
567  if (add_common_pssh_ && !widevine_classic) {
568  std::set<std::vector<uint8_t>> key_ids;
569  for (const EncryptionKeyMap::value_type& pair : encryption_key_map) {
570  key_ids.insert(pair.second->key_id);
571  }
572 
573  // Create a common system PSSH box.
574  ProtectionSystemSpecificInfo info;
575  info.set_system_id(kCommonSystemId, arraysize(kCommonSystemId));
576  info.set_pssh_box_version(1);
577  for (const std::vector<uint8_t>& key_id : key_ids) {
578  info.add_key_id(key_id);
579  }
580 
581  for (const EncryptionKeyMap::value_type& pair : encryption_key_map) {
582  pair.second->key_system_info.push_back(info);
583  }
584  }
585 
586  DCHECK(!encryption_key_map.empty());
587  if (!enable_key_rotation) {
588  encryption_key_map_.swap(encryption_key_map);
589  return true;
590  }
591  return PushToKeyPool(&encryption_key_map);
592 }
593 
594 bool WidevineKeySource::PushToKeyPool(
595  EncryptionKeyMap* encryption_key_map) {
596  DCHECK(key_pool_);
597  DCHECK(encryption_key_map);
598  Status status =
599  key_pool_->Push(scoped_refptr<RefCountedEncryptionKeyMap>(
600  new RefCountedEncryptionKeyMap(encryption_key_map)),
601  kInfiniteTimeout);
602  encryption_key_map->clear();
603  if (!status.ok()) {
604  DCHECK_EQ(error::STOPPED, status.error_code());
605  return false;
606  }
607  return true;
608 }
609 
610 } // namespace media
611 } // namespace shaka
Status FetchKeys(const std::vector< uint8_t > &pssh_box) override
Status GetCryptoPeriodKey(uint32_t crypto_period_index, TrackType track_type, EncryptionKey *key) override
Status GetKey(TrackType track_type, EncryptionKey *key) override
static std::string TrackTypeToString(TrackType track_type)
Convert TrackType to string.
Definition: key_source.cc:33
static TrackType GetTrackTypeFromString(const std::string &track_type_string)
Convert string representation of track type to enum representation.
Definition: key_source.cc:19
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
WidevineKeySource(const std::string &server_url, bool add_common_pssh)
void set_signer(std::unique_ptr< RequestSigner > signer)
bool Parse(const uint8_t *data, size_t data_size)