DASH Media Packaging SDK
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator
widevine_key_source.cc
1 // Copyright 2014 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/widevine_key_source.h"
8 
9 #include <set>
10 
11 #include "packager/base/base64.h"
12 #include "packager/base/bind.h"
13 #include "packager/base/json/json_reader.h"
14 #include "packager/base/json/json_writer.h"
15 #include "packager/media/base/fixed_key_source.h"
16 #include "packager/media/base/http_key_fetcher.h"
17 #include "packager/media/base/producer_consumer_queue.h"
18 #include "packager/media/base/protection_system_specific_info.h"
19 #include "packager/media/base/rcheck.h"
20 #include "packager/media/base/request_signer.h"
21 #include "packager/media/base/widevine_pssh_data.pb.h"
22 
23 namespace shaka {
24 namespace {
25 
// Named value for the |enable_key_rotation| argument of FetchKeysInternal();
// callers pass either kEnableKeyRotation or !kEnableKeyRotation.
constexpr bool kEnableKeyRotation = true;

// "status" value of a successful license response.
constexpr char kLicenseStatusOK[] = "OK";
// "status" value the server may return intermittently. It denotes a transient
// error, i.e. re-sending the same request may succeed.
constexpr char kLicenseStatusTransientError[] = "INTERNAL_ERROR";

// Upper bound on the number of attempts made for a single key request when
// the server keeps reporting a transient error (or the fetch times out).
constexpr int kNumTransientErrorRetries = 5;
// Delay before the second attempt; doubled after every further failed attempt.
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default crypto period count, i.e. the number of crypto periods whose keys
// are requested by every key-rotation-enabled request.
constexpr int kDefaultCryptoPeriodCount = 10;
// Maximum time GetKeyInternal() waits for a key to show up in the key pool.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Timeout handed to the HttpKeyFetcher created by the constructor.
constexpr int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
43 
44 bool Base64StringToBytes(const std::string& base64_string,
45  std::vector<uint8_t>* bytes) {
46  DCHECK(bytes);
47  std::string str;
48  if (!base::Base64Decode(base64_string, &str))
49  return false;
50  bytes->assign(str.begin(), str.end());
51  return true;
52 }
53 
54 void BytesToBase64String(const std::vector<uint8_t>& bytes,
55  std::string* base64_string) {
56  DCHECK(base64_string);
57  base::Base64Encode(base::StringPiece(reinterpret_cast<const char*>
58  (bytes.data()), bytes.size()),
59  base64_string);
60 }
61 
// Reads the string entry "key" of |track_dict|, base64-decodes it and stores
// the resulting bytes in |key| (must not be null).
// Returns true on success. A missing "key" entry or invalid base64 fails the
// corresponding RCHECK (presumably logging and returning false; the macro is
// defined in rcheck.h, outside this file).
62 bool GetKeyFromTrack(const base::DictionaryValue& track_dict,
63  std::vector<uint8_t>* key) {
64  DCHECK(key);
65  std::string key_base64_string;
66  RCHECK(track_dict.GetString("key", &key_base64_string));
// The still-encoded key is only written to the verbose (level 2) log.
67  VLOG(2) << "Key:" << key_base64_string;
68  RCHECK(Base64StringToBytes(key_base64_string, key));
69  return true;
70 }
71 
// Reads the string entry "key_id" of |track_dict|, base64-decodes it and
// stores the resulting bytes in |key_id| (must not be null).
// Returns true on success. A missing "key_id" entry or invalid base64 fails
// the corresponding RCHECK (presumably logging and returning false; the macro
// is defined in rcheck.h, outside this file).
72 bool GetKeyIdFromTrack(const base::DictionaryValue& track_dict,
73  std::vector<uint8_t>* key_id) {
74  DCHECK(key_id);
75  std::string key_id_base64_string;
76  RCHECK(track_dict.GetString("key_id", &key_id_base64_string));
// The still-encoded key ID is only written to the verbose (level 2) log.
77  VLOG(2) << "Keyid:" << key_id_base64_string;
78  RCHECK(Base64StringToBytes(key_id_base64_string, key_id));
79  return true;
80 }
81 
// Extracts the Widevine PSSH data of a track.
// Looks up the list entry "pssh" of |track_dict|, takes its first element
// (a dictionary), requires that element's "drm_type" to be "WIDEVINE" and
// base64-decodes its "data" string into |pssh_data| (must not be null).
// Returns true on success and false if the drm_type differs; any missing
// entry or invalid base64 fails the corresponding RCHECK (presumably logging
// and returning false; the macro is defined in rcheck.h, outside this file).
82 bool GetPsshDataFromTrack(const base::DictionaryValue& track_dict,
83  std::vector<uint8_t>* pssh_data) {
84  DCHECK(pssh_data);
85 
86  const base::ListValue* pssh_list;
87  RCHECK(track_dict.GetList("pssh", &pssh_list));
88  // Invariant check. We don't want to crash in release mode if possible.
89  // The following code handles it gracefully if GetSize() does not return 1.
90  DCHECK_EQ(1u, pssh_list->GetSize());
91 
// Only element 0 is ever read; additional list elements are ignored.
92  const base::DictionaryValue* pssh_dict;
93  RCHECK(pssh_list->GetDictionary(0, &pssh_dict));
94  std::string drm_type;
95  RCHECK(pssh_dict->GetString("drm_type", &drm_type));
96  if (drm_type != "WIDEVINE") {
97  LOG(ERROR) << "Expecting drm_type 'WIDEVINE', get '" << drm_type << "'.";
98  return false;
99  }
100  std::string pssh_data_base64_string;
101  RCHECK(pssh_dict->GetString("data", &pssh_data_base64_string));
102 
103  VLOG(2) << "Pssh Data:" << pssh_data_base64_string;
104  RCHECK(Base64StringToBytes(pssh_data_base64_string, pssh_data));
105  return true;
106 }
107 
108 } // namespace
109 
110 namespace media {
111 
// Creates a key source that talks to the server at |server_url|.
// |add_common_pssh| is stored in add_common_pssh_ and later consulted by
// ExtractEncryptionKey() to decide whether a common-system PSSH box is added
// to the fetched keys.
// The constructor installs a default HttpKeyFetcher (replaceable through
// set_key_fetcher()) and starts the key production thread right away; that
// thread runs FetchKeysTask(), which blocks on start_key_production_ until it
// is signaled by GetCryptoPeriodKey() or by the destructor.
112 WidevineKeySource::WidevineKeySource(const std::string& server_url,
113  bool add_common_pssh)
114  : key_production_thread_("KeyProductionThread",
115  base::Bind(&WidevineKeySource::FetchKeysTask,
116  base::Unretained(this))),
117  key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
118  server_url_(server_url),
119  crypto_period_count_(kDefaultCryptoPeriodCount),
120  add_common_pssh_(add_common_pssh),
121  key_production_started_(false),
122  start_key_production_(base::WaitableEvent::ResetPolicy::AUTOMATIC,
123  base::WaitableEvent::InitialState::NOT_SIGNALED),
124  first_crypto_period_index_(0) {
// NOTE(review): the thread is started from inside the constructor and is bound
// to |this| via base::Unretained; it is joined in the destructor.
125  key_production_thread_.Start();
126 }
127 
128 WidevineKeySource::~WidevineKeySource() {
129  if (key_pool_)
130  key_pool_->Stop();
131  if (key_production_thread_.HasBeenStarted()) {
132  // Signal the production thread to start key production if it is not
133  // signaled yet so the thread can be joined.
134  start_key_production_.Signal();
135  key_production_thread_.Join();
136  }
137 }
138 
139 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
140  const std::string& policy) {
141  base::AutoLock scoped_lock(lock_);
142  request_dict_.Clear();
143  std::string content_id_base64_string;
144  BytesToBase64String(content_id, &content_id_base64_string);
145  request_dict_.SetString("content_id", content_id_base64_string);
146  request_dict_.SetString("policy", policy);
147  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
148 }
149 
150 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& pssh_box) {
151  const std::vector<uint8_t> widevine_system_id(
152  kWidevineSystemId, kWidevineSystemId + arraysize(kWidevineSystemId));
153 
155  if (!info.Parse(pssh_box.data(), pssh_box.size()))
156  return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
157 
158  if (info.system_id() == widevine_system_id) {
159  base::AutoLock scoped_lock(lock_);
160  request_dict_.Clear();
161  std::string pssh_data_base64_string;
162 
163  BytesToBase64String(info.pssh_data(), &pssh_data_base64_string);
164  request_dict_.SetString("pssh_data", pssh_data_base64_string);
165  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
166  } else if (!info.key_ids().empty()) {
167  // This is not a Widevine PSSH box. Try making the request for the key-IDs.
168  // Even if this is a different key-system, it should still work. Either
169  // the server will not recognize it and return an error, or it will
170  // recognize it and the key must be correct (or the content is bad).
171  return FetchKeys(info.key_ids());
172  } else {
173  return Status(error::NOT_FOUND, "No key IDs given in PSSH box.");
174  }
175 }
176 
178  const std::vector<std::vector<uint8_t>>& key_ids) {
179  base::AutoLock scoped_lock(lock_);
180  request_dict_.Clear();
181  std::string pssh_data_base64_string;
182 
183  // Generate Widevine PSSH data from the key-IDs.
184  WidevinePsshData widevine_pssh_data;
185  for (size_t i = 0; i < key_ids.size(); i++) {
186  widevine_pssh_data.add_key_id(key_ids[i].data(), key_ids[i].size());
187  }
188 
189  const std::string serialized_string = widevine_pssh_data.SerializeAsString();
190  std::vector<uint8_t> pssh_data(serialized_string.begin(),
191  serialized_string.end());
192 
193  BytesToBase64String(pssh_data, &pssh_data_base64_string);
194  request_dict_.SetString("pssh_data", pssh_data_base64_string);
195  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
196 }
197 
199  base::AutoLock scoped_lock(lock_);
200  request_dict_.Clear();
201  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
202  // instead as 32-bit integer can be lossless represented using double.
203  request_dict_.SetDouble("asset_id", asset_id);
204  return FetchKeysInternal(!kEnableKeyRotation, 0, true);
205 }
206 
207 Status WidevineKeySource::GetKey(TrackType track_type, EncryptionKey* key) {
208  DCHECK(key);
209  if (encryption_key_map_.find(track_type) == encryption_key_map_.end()) {
210  return Status(error::INTERNAL_ERROR,
211  "Cannot find key of type " + TrackTypeToString(track_type));
212  }
213  *key = *encryption_key_map_[track_type];
214  return Status::OK;
215 }
216 
217 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
218  EncryptionKey* key) {
219  DCHECK(key);
220  for (const auto& pair : encryption_key_map_) {
221  if (pair.second->key_id == key_id) {
222  *key = *pair.second;
223  return Status::OK;
224  }
225  }
226  return Status(error::INTERNAL_ERROR,
227  "Cannot find key with specified key ID");
228 }
229 
// Returns, in |key|, the key of |track_type| for crypto period
// |crypto_period_index|.
// The first call lazily starts key production: under |lock_| it picks the
// first crypto period index to request, creates the key pool and signals the
// production thread. Every call then delegates to GetKeyInternal(), which
// waits for the requested crypto period to become available in the pool.
230 Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
231  TrackType track_type,
232  EncryptionKey* key) {
233  DCHECK(key_production_thread_.HasBeenStarted());
234  // TODO(kqyang): This is not elegant. Consider refactoring later.
235  {
236  base::AutoLock scoped_lock(lock_);
237  if (!key_production_started_) {
238  // Another client may have a slightly smaller starting crypto period
239  // index. Set the initial value to account for that.
// The conditional keeps the unsigned index from wrapping when it is 0.
240  first_crypto_period_index_ =
241  crypto_period_index ? crypto_period_index - 1 : 0;
242  DCHECK(!key_pool_);
243  key_pool_.reset(new EncryptionKeyQueue(crypto_period_count_,
244  first_crypto_period_index_));
// The pool and the first index are set up before the signal because
// FetchKeysTask() reads both right after waking up.
245  start_key_production_.Signal();
246  key_production_started_ = true;
247  }
248  }
249  return GetKeyInternal(crypto_period_index, track_type, key);
250 }
251 
// Takes ownership of |signer|, replacing any previously installed signer.
// When a signer is set, GenerateKeyMessage() adds "signature" and "signer"
// entries to every request; passing an empty pointer disables signing.
252 void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
253  signer_ = std::move(signer);
254 }
255 
257  std::unique_ptr<KeyFetcher> key_fetcher) {
258  key_fetcher_ = std::move(key_fetcher);
259 }
260 
261 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
262  TrackType track_type,
263  EncryptionKey* key) {
264  DCHECK(key_pool_);
265  DCHECK(key);
266  DCHECK_LE(track_type, NUM_VALID_TRACK_TYPES);
267  DCHECK_NE(track_type, TRACK_TYPE_UNKNOWN);
268 
269  std::shared_ptr<EncryptionKeyMap> encryption_key_map;
270  Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
271  kGetKeyTimeoutInSeconds * 1000);
272  if (!status.ok()) {
273  if (status.error_code() == error::STOPPED) {
274  CHECK(!common_encryption_request_status_.ok());
275  return common_encryption_request_status_;
276  }
277  return status;
278  }
279 
280  if (encryption_key_map->find(track_type) == encryption_key_map->end()) {
281  return Status(error::INTERNAL_ERROR,
282  "Cannot find key of type " + TrackTypeToString(track_type));
283  }
284  *key = *encryption_key_map->at(track_type);
285  return Status::OK;
286 }
287 
288 void WidevineKeySource::FetchKeysTask() {
289  // Wait until key production is signaled.
290  start_key_production_.Wait();
291  if (!key_pool_ || key_pool_->Stopped())
292  return;
293 
294  Status status = FetchKeysInternal(kEnableKeyRotation,
295  first_crypto_period_index_,
296  false);
297  while (status.ok()) {
298  first_crypto_period_index_ += crypto_period_count_;
299  status = FetchKeysInternal(kEnableKeyRotation,
300  first_crypto_period_index_,
301  false);
302  }
303  common_encryption_request_status_ = status;
304  key_pool_->Stop();
305 }
306 
307 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
308  uint32_t first_crypto_period_index,
309  bool widevine_classic) {
310  std::string request;
311  FillRequest(enable_key_rotation,
312  first_crypto_period_index,
313  &request);
314 
315  std::string message;
316  Status status = GenerateKeyMessage(request, &message);
317  if (!status.ok())
318  return status;
319  VLOG(1) << "Message: " << message;
320 
321  std::string raw_response;
322  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
323 
324  // Perform client side retries if seeing server transient error to workaround
325  // server limitation.
326  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
327  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
328  if (status.ok()) {
329  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
330 
331  std::string response;
332  if (!DecodeResponse(raw_response, &response)) {
333  return Status(error::SERVER_ERROR,
334  "Failed to decode response '" + raw_response + "'.");
335  }
336 
337  bool transient_error = false;
338  if (ExtractEncryptionKey(enable_key_rotation,
339  widevine_classic,
340  response,
341  &transient_error))
342  return Status::OK;
343 
344  if (!transient_error) {
345  return Status(
346  error::SERVER_ERROR,
347  "Failed to extract encryption key from '" + response + "'.");
348  }
349  } else if (status.error_code() != error::TIME_OUT) {
350  return status;
351  }
352 
353  // Exponential backoff.
354  if (i != kNumTransientErrorRetries - 1) {
355  base::PlatformThread::Sleep(
356  base::TimeDelta::FromMilliseconds(sleep_duration));
357  sleep_duration *= 2;
358  }
359  }
360  return Status(error::SERVER_ERROR,
361  "Failed to recover from server internal error.");
362 }
363 
364 void WidevineKeySource::FillRequest(bool enable_key_rotation,
365  uint32_t first_crypto_period_index,
366  std::string* request) {
367  DCHECK(request);
368  DCHECK(!request_dict_.empty());
369 
370  // Build tracks.
371  base::ListValue* tracks = new base::ListValue();
372 
373  base::DictionaryValue* track_sd = new base::DictionaryValue();
374  track_sd->SetString("type", "SD");
375  tracks->Append(track_sd);
376  base::DictionaryValue* track_hd = new base::DictionaryValue();
377  track_hd->SetString("type", "HD");
378  tracks->Append(track_hd);
379  base::DictionaryValue* track_uhd1 = new base::DictionaryValue();
380  track_uhd1->SetString("type", "UHD1");
381  tracks->Append(track_uhd1);
382  base::DictionaryValue* track_uhd2 = new base::DictionaryValue();
383  track_uhd2->SetString("type", "UHD2");
384  tracks->Append(track_uhd2);
385  base::DictionaryValue* track_audio = new base::DictionaryValue();
386  track_audio->SetString("type", "AUDIO");
387  tracks->Append(track_audio);
388 
389  request_dict_.Set("tracks", tracks);
390 
391  // Build DRM types.
392  base::ListValue* drm_types = new base::ListValue();
393  drm_types->AppendString("WIDEVINE");
394  request_dict_.Set("drm_types", drm_types);
395 
396  // Build key rotation fields.
397  if (enable_key_rotation) {
398  // Javascript/JSON does not support int64_t or unsigned numbers. Use double
399  // instead as 32-bit integer can be lossless represented using double.
400  request_dict_.SetDouble("first_crypto_period_index",
401  first_crypto_period_index);
402  request_dict_.SetInteger("crypto_period_count", crypto_period_count_);
403  }
404 
405  base::JSONWriter::WriteWithOptions(
406  request_dict_,
407  // Write doubles that have no fractional part as a normal integer, i.e.
408  // without using exponential notation or appending a '.0'.
409  base::JSONWriter::OPTIONS_OMIT_DOUBLE_TYPE_PRESERVATION, request);
410 }
411 
412 Status WidevineKeySource::GenerateKeyMessage(const std::string& request,
413  std::string* message) {
414  DCHECK(message);
415 
416  std::string request_base64_string;
417  base::Base64Encode(request, &request_base64_string);
418 
419  base::DictionaryValue request_dict;
420  request_dict.SetString("request", request_base64_string);
421 
422  // Sign the request.
423  if (signer_) {
424  std::string signature;
425  if (!signer_->GenerateSignature(request, &signature))
426  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
427 
428  std::string signature_base64_string;
429  base::Base64Encode(signature, &signature_base64_string);
430 
431  request_dict.SetString("signature", signature_base64_string);
432  request_dict.SetString("signer", signer_->signer_name());
433  }
434 
435  base::JSONWriter::Write(request_dict, message);
436  return Status::OK;
437 }
438 
439 bool WidevineKeySource::DecodeResponse(
440  const std::string& raw_response,
441  std::string* response) {
442  DCHECK(response);
443 
444  // Extract base64 formatted response from JSON formatted raw response.
445  std::unique_ptr<base::Value> root(base::JSONReader::Read(raw_response));
446  if (!root) {
447  LOG(ERROR) << "'" << raw_response << "' is not in JSON format.";
448  return false;
449  }
450  const base::DictionaryValue* response_dict = NULL;
451  RCHECK(root->GetAsDictionary(&response_dict));
452 
453  std::string response_base64_string;
454  RCHECK(response_dict->GetString("response", &response_base64_string));
455  RCHECK(base::Base64Decode(response_base64_string, response));
456  return true;
457 }
458 
459 bool WidevineKeySource::ExtractEncryptionKey(
460  bool enable_key_rotation,
461  bool widevine_classic,
462  const std::string& response,
463  bool* transient_error) {
464  DCHECK(transient_error);
465  *transient_error = false;
466 
467  std::unique_ptr<base::Value> root(base::JSONReader::Read(response));
468  if (!root) {
469  LOG(ERROR) << "'" << response << "' is not in JSON format.";
470  return false;
471  }
472 
473  const base::DictionaryValue* license_dict = NULL;
474  RCHECK(root->GetAsDictionary(&license_dict));
475 
476  std::string license_status;
477  RCHECK(license_dict->GetString("status", &license_status));
478  if (license_status != kLicenseStatusOK) {
479  LOG(ERROR) << "Received non-OK license response: " << response;
480  *transient_error = (license_status == kLicenseStatusTransientError);
481  return false;
482  }
483 
484  const base::ListValue* tracks;
485  RCHECK(license_dict->GetList("tracks", &tracks));
486  // Should have at least one track per crypto_period.
487  RCHECK(enable_key_rotation ? tracks->GetSize() >= 1 * crypto_period_count_
488  : tracks->GetSize() >= 1);
489 
490  int current_crypto_period_index = first_crypto_period_index_;
491 
492  EncryptionKeyMap encryption_key_map;
493  for (size_t i = 0; i < tracks->GetSize(); ++i) {
494  const base::DictionaryValue* track_dict;
495  RCHECK(tracks->GetDictionary(i, &track_dict));
496 
497  if (enable_key_rotation) {
498  int crypto_period_index;
499  RCHECK(
500  track_dict->GetInteger("crypto_period_index", &crypto_period_index));
501  if (crypto_period_index != current_crypto_period_index) {
502  if (crypto_period_index != current_crypto_period_index + 1) {
503  LOG(ERROR) << "Expecting crypto period index "
504  << current_crypto_period_index << " or "
505  << current_crypto_period_index + 1 << "; Seen "
506  << crypto_period_index << " at track " << i;
507  return false;
508  }
509  if (!PushToKeyPool(&encryption_key_map))
510  return false;
511  ++current_crypto_period_index;
512  }
513  }
514 
515  std::string track_type_str;
516  RCHECK(track_dict->GetString("type", &track_type_str));
517  TrackType track_type = GetTrackTypeFromString(track_type_str);
518  DCHECK_NE(TRACK_TYPE_UNKNOWN, track_type);
519  RCHECK(encryption_key_map.find(track_type) == encryption_key_map.end());
520 
521  std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
522 
523  if (!GetKeyFromTrack(*track_dict, &encryption_key->key))
524  return false;
525 
526  // Get key ID and PSSH data for CENC content only.
527  if (!widevine_classic) {
528  if (!GetKeyIdFromTrack(*track_dict, &encryption_key->key_id))
529  return false;
530 
531  ProtectionSystemSpecificInfo info;
532  info.add_key_id(encryption_key->key_id);
533  info.set_system_id(kWidevineSystemId, arraysize(kWidevineSystemId));
534  info.set_pssh_box_version(0);
535 
536  std::vector<uint8_t> pssh_data;
537  if (!GetPsshDataFromTrack(*track_dict, &pssh_data))
538  return false;
539  info.set_pssh_data(pssh_data);
540 
541  encryption_key->key_system_info.push_back(info);
542  }
543  encryption_key_map[track_type] = std::move(encryption_key);
544  }
545 
546  // If the flag exists, create a common system ID PSSH box that contains the
547  // key IDs of all the keys.
548  if (add_common_pssh_ && !widevine_classic) {
549  std::set<std::vector<uint8_t>> key_ids;
550  for (const EncryptionKeyMap::value_type& pair : encryption_key_map) {
551  key_ids.insert(pair.second->key_id);
552  }
553 
554  // Create a common system PSSH box.
555  ProtectionSystemSpecificInfo info;
556  info.set_system_id(kCommonSystemId, arraysize(kCommonSystemId));
557  info.set_pssh_box_version(1);
558  for (const std::vector<uint8_t>& key_id : key_ids) {
559  info.add_key_id(key_id);
560  }
561 
562  for (const EncryptionKeyMap::value_type& pair : encryption_key_map) {
563  pair.second->key_system_info.push_back(info);
564  }
565  }
566 
567  DCHECK(!encryption_key_map.empty());
568  if (!enable_key_rotation) {
569  encryption_key_map_.swap(encryption_key_map);
570  return true;
571  }
572  return PushToKeyPool(&encryption_key_map);
573 }
574 
575 bool WidevineKeySource::PushToKeyPool(
576  EncryptionKeyMap* encryption_key_map) {
577  DCHECK(key_pool_);
578  DCHECK(encryption_key_map);
579  auto encryption_key_map_shared = std::make_shared<EncryptionKeyMap>();
580  encryption_key_map_shared->swap(*encryption_key_map);
581  Status status = key_pool_->Push(encryption_key_map_shared, kInfiniteTimeout);
582  if (!status.ok()) {
583  DCHECK_EQ(error::STOPPED, status.error_code());
584  return false;
585  }
586  return true;
587 }
588 
589 } // namespace media
590 } // namespace shaka
Status FetchKeys(const std::vector< uint8_t > &pssh_box) override
Status GetCryptoPeriodKey(uint32_t crypto_period_index, TrackType track_type, EncryptionKey *key) override
Status GetKey(TrackType track_type, EncryptionKey *key) override
static std::string TrackTypeToString(TrackType track_type)
Convert TrackType to string.
Definition: key_source.cc:37
static TrackType GetTrackTypeFromString(const std::string &track_type_string)
Convert string representation of track type to enum representation.
Definition: key_source.cc:19
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
WidevineKeySource(const std::string &server_url, bool add_common_pssh)
void set_signer(std::unique_ptr< RequestSigner > signer)
bool Parse(const uint8_t *data, size_t data_size)