Shaka Packager SDK
widevine_key_source.cc
1 // Copyright 2014 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/widevine_key_source.h"
8 
9 #include "packager/base/base64.h"
10 #include "packager/base/bind.h"
11 #include "packager/base/strings/string_number_conversions.h"
12 #include "packager/media/base/http_key_fetcher.h"
13 #include "packager/media/base/network_util.h"
14 #include "packager/media/base/producer_consumer_queue.h"
15 #include "packager/media/base/protection_system_ids.h"
16 #include "packager/media/base/protection_system_specific_info.h"
17 #include "packager/media/base/proto_json_util.h"
18 #include "packager/media/base/pssh_generator_util.h"
19 #include "packager/media/base/rcheck.h"
20 #include "packager/media/base/request_signer.h"
21 #include "packager/media/base/widevine_common_encryption.pb.h"
22 
23 namespace shaka {
24 namespace media {
25 namespace {
26 
// Passed as the |enable_key_rotation| argument of FetchKeysInternal().
constexpr bool kEnableKeyRotation = true;

// How many times a key request is attempted when the server reports a
// transient error, and the delay before the first retry (the delay doubles
// after every retry).
constexpr int kNumTransientErrorRetries = 5;
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Default crypto period count: the number of crypto periods requested on
// every key-rotation-enabled request.
constexpr int kDefaultCryptoPeriodCount = 10;
// Maximum time GetKeyInternal() waits for the keys of one crypto period.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Timeout handed to the HttpKeyFetcher for each request.
constexpr int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
39 
40 CommonEncryptionRequest::ProtectionScheme ToCommonEncryptionProtectionScheme(
41  FourCC protection_scheme) {
42  switch (protection_scheme) {
43  case FOURCC_cenc:
44  return CommonEncryptionRequest::CENC;
45  case FOURCC_cbcs:
46  case kAppleSampleAesProtectionScheme:
47  // Treat sample aes as a variant of cbcs.
48  return CommonEncryptionRequest::CBCS;
49  case FOURCC_cbc1:
50  return CommonEncryptionRequest::CBC1;
51  case FOURCC_cens:
52  return CommonEncryptionRequest::CENS;
53  default:
54  LOG(WARNING) << "Ignore unrecognized protection scheme "
55  << FourCCToString(protection_scheme);
56  return CommonEncryptionRequest::UNSPECIFIED;
57  }
58 }
59 
60 ProtectionSystemSpecificInfo ProtectionSystemInfoFromPsshProto(
61  const CommonEncryptionResponse::Track::Pssh& pssh_proto) {
62  PsshBoxBuilder pssh_builder;
63  pssh_builder.set_system_id(kWidevineSystemId, arraysize(kWidevineSystemId));
64 
65  if (pssh_proto.has_boxes()) {
66  return {pssh_builder.system_id(),
67  std::vector<uint8_t>(pssh_proto.boxes().begin(),
68  pssh_proto.boxes().end())};
69  } else {
70  pssh_builder.set_pssh_box_version(0);
71  const std::vector<uint8_t> pssh_data(pssh_proto.data().begin(),
72  pssh_proto.data().end());
73  pssh_builder.set_pssh_data(pssh_data);
74  return {pssh_builder.system_id(), pssh_builder.CreateBox()};
75  }
76 }
77 
78 } // namespace
79 
80 WidevineKeySource::WidevineKeySource(const std::string& server_url,
81  int protection_system_flags,
82  FourCC protection_scheme)
83  // Widevine PSSH is fetched from Widevine license server.
84  : KeySource(protection_system_flags & ~WIDEVINE_PROTECTION_SYSTEM_FLAG,
85  protection_scheme),
86  generate_widevine_protection_system_(
87  // Generate Widevine protection system if there are no other
88  // protection system specified.
89  protection_system_flags == NO_PROTECTION_SYSTEM_FLAG ||
90  protection_system_flags & WIDEVINE_PROTECTION_SYSTEM_FLAG),
91  key_production_thread_("KeyProductionThread",
92  base::Bind(&WidevineKeySource::FetchKeysTask,
93  base::Unretained(this))),
94  key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
95  server_url_(server_url),
96  crypto_period_count_(kDefaultCryptoPeriodCount),
97  protection_scheme_(protection_scheme),
98  start_key_production_(base::WaitableEvent::ResetPolicy::AUTOMATIC,
99  base::WaitableEvent::InitialState::NOT_SIGNALED) {
100  key_production_thread_.Start();
101 }
102 
103 WidevineKeySource::~WidevineKeySource() {
104  if (key_pool_)
105  key_pool_->Stop();
106  if (key_production_thread_.HasBeenStarted()) {
107  // Signal the production thread to start key production if it is not
108  // signaled yet so the thread can be joined.
109  start_key_production_.Signal();
110  key_production_thread_.Join();
111  }
112 }
113 
114 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
115  const std::string& policy) {
116  base::AutoLock scoped_lock(lock_);
117  common_encryption_request_.reset(new CommonEncryptionRequest);
118  common_encryption_request_->set_content_id(content_id.data(),
119  content_id.size());
120  common_encryption_request_->set_policy(policy);
121  common_encryption_request_->set_protection_scheme(
122  ToCommonEncryptionProtectionScheme(protection_scheme_));
123  if (enable_entitlement_license_)
124  common_encryption_request_->set_enable_entitlement_license(true);
125 
126  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
127 }
128 
129 Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
130  const std::vector<uint8_t>& init_data) {
131  std::vector<uint8_t> pssh_data;
132  uint32_t asset_id = 0;
133  switch (init_data_type) {
134  case EmeInitDataType::CENC: {
135  const std::vector<uint8_t> widevine_system_id(
136  kWidevineSystemId, kWidevineSystemId + arraysize(kWidevineSystemId));
137  std::vector<ProtectionSystemSpecificInfo> protection_systems_info;
139  init_data.data(), init_data.size(), &protection_systems_info)) {
140  return Status(error::PARSER_FAILURE, "Error parsing the PSSH boxes.");
141  }
142  for (const auto& info : protection_systems_info) {
143  std::unique_ptr<PsshBoxBuilder> pssh_builder =
144  PsshBoxBuilder::ParseFromBox(info.psshs.data(), info.psshs.size());
145  if (!pssh_builder)
146  return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
147  // Use Widevine PSSH if available otherwise construct a Widevine PSSH
148  // from the first available key ids.
149  if (info.system_id == widevine_system_id) {
150  pssh_data = pssh_builder->pssh_data();
151  break;
152  } else if (pssh_data.empty() && !pssh_builder->key_ids().empty()) {
153  pssh_data =
154  GenerateWidevinePsshDataFromKeyIds(pssh_builder->key_ids());
155  // Continue to see if there is any Widevine PSSH. The KeyId generated
156  // PSSH is only used if a Widevine PSSH could not be found.
157  continue;
158  }
159  }
160  if (pssh_data.empty())
161  return Status(error::INVALID_ARGUMENT, "No supported PSSHs found.");
162  break;
163  }
164  case EmeInitDataType::WEBM: {
165  pssh_data = GenerateWidevinePsshDataFromKeyIds({init_data});
166  break;
167  }
168  case EmeInitDataType::WIDEVINE_CLASSIC:
169  if (init_data.size() < sizeof(asset_id))
170  return Status(error::INVALID_ARGUMENT, "Invalid asset id.");
171  asset_id = ntohlFromBuffer(init_data.data());
172  break;
173  default:
174  LOG(ERROR) << "Init data type " << static_cast<int>(init_data_type)
175  << " not supported.";
176  return Status(error::INVALID_ARGUMENT, "Unsupported init data type.");
177  }
178  const bool widevine_classic =
179  init_data_type == EmeInitDataType::WIDEVINE_CLASSIC;
180  base::AutoLock scoped_lock(lock_);
181  common_encryption_request_.reset(new CommonEncryptionRequest);
182  if (widevine_classic) {
183  common_encryption_request_->set_asset_id(asset_id);
184  } else {
185  common_encryption_request_->set_pssh_data(pssh_data.data(),
186  pssh_data.size());
187  }
188  return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
189 }
190 
191 Status WidevineKeySource::GetKey(const std::string& stream_label,
192  EncryptionKey* key) {
193  DCHECK(key);
194  if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
195  return Status(error::INTERNAL_ERROR,
196  "Cannot find key for '" + stream_label + "'.");
197  }
198  *key = *encryption_key_map_[stream_label];
199  return Status::OK;
200 }
201 
202 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
203  EncryptionKey* key) {
204  DCHECK(key);
205  for (const auto& pair : encryption_key_map_) {
206  if (pair.second->key_id == key_id) {
207  *key = *pair.second;
208  return Status::OK;
209  }
210  }
211  return Status(error::INTERNAL_ERROR,
212  "Cannot find key with specified key ID");
213 }
214 
215 Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
216  const std::string& stream_label,
217  EncryptionKey* key) {
218  DCHECK(key_production_thread_.HasBeenStarted());
219  // TODO(kqyang): This is not elegant. Consider refactoring later.
220  {
221  base::AutoLock scoped_lock(lock_);
222  if (!key_production_started_) {
223  // Another client may have a slightly smaller starting crypto period
224  // index. Set the initial value to account for that.
225  first_crypto_period_index_ =
226  crypto_period_index ? crypto_period_index - 1 : 0;
227  DCHECK(!key_pool_);
228  const size_t queue_size = crypto_period_count_ * 10;
229  key_pool_.reset(
230  new EncryptionKeyQueue(queue_size, first_crypto_period_index_));
231  start_key_production_.Signal();
232  key_production_started_ = true;
233  }
234  }
235  return GetKeyInternal(crypto_period_index, stream_label, key);
236 }
237 
238 void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
239  signer_ = std::move(signer);
240 }
241 
243  std::unique_ptr<KeyFetcher> key_fetcher) {
244  key_fetcher_ = std::move(key_fetcher);
245 }
246 
247 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
248  const std::string& stream_label,
249  EncryptionKey* key) {
250  DCHECK(key_pool_);
251  DCHECK(key);
252 
253  std::shared_ptr<EncryptionKeyMap> encryption_key_map;
254  Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
255  kGetKeyTimeoutInSeconds * 1000);
256  if (!status.ok()) {
257  if (status.error_code() == error::STOPPED) {
258  CHECK(!common_encryption_request_status_.ok());
259  return common_encryption_request_status_;
260  }
261  return status;
262  }
263 
264  if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
265  return Status(error::INTERNAL_ERROR,
266  "Cannot find key for '" + stream_label + "'.");
267  }
268  *key = *encryption_key_map->at(stream_label);
269  return Status::OK;
270 }
271 
272 void WidevineKeySource::FetchKeysTask() {
273  // Wait until key production is signaled.
274  start_key_production_.Wait();
275  if (!key_pool_ || key_pool_->Stopped())
276  return;
277 
278  Status status = FetchKeysInternal(kEnableKeyRotation,
279  first_crypto_period_index_,
280  false);
281  while (status.ok()) {
282  first_crypto_period_index_ += crypto_period_count_;
283  status = FetchKeysInternal(kEnableKeyRotation,
284  first_crypto_period_index_,
285  false);
286  }
287  common_encryption_request_status_ = status;
288  key_pool_->Stop();
289 }
290 
291 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
292  uint32_t first_crypto_period_index,
293  bool widevine_classic) {
294  CommonEncryptionRequest request;
295  FillRequest(enable_key_rotation, first_crypto_period_index, &request);
296 
297  std::string message;
298  Status status = GenerateKeyMessage(request, &message);
299  if (!status.ok())
300  return status;
301  VLOG(1) << "Message: " << message;
302 
303  std::string raw_response;
304  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
305 
306  // Perform client side retries if seeing server transient error to workaround
307  // server limitation.
308  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
309  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
310  if (status.ok()) {
311  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
312 
313  bool transient_error = false;
314  if (ExtractEncryptionKey(enable_key_rotation, widevine_classic,
315  raw_response, &transient_error))
316  return Status::OK;
317 
318  if (!transient_error) {
319  return Status(
320  error::SERVER_ERROR,
321  "Failed to extract encryption key from '" + raw_response + "'.");
322  }
323  } else if (status.error_code() != error::TIME_OUT) {
324  return status;
325  }
326 
327  // Exponential backoff.
328  if (i != kNumTransientErrorRetries - 1) {
329  base::PlatformThread::Sleep(
330  base::TimeDelta::FromMilliseconds(sleep_duration));
331  sleep_duration *= 2;
332  }
333  }
334  return Status(error::SERVER_ERROR,
335  "Failed to recover from server internal error.");
336 }
337 
338 void WidevineKeySource::FillRequest(bool enable_key_rotation,
339  uint32_t first_crypto_period_index,
340  CommonEncryptionRequest* request) {
341  DCHECK(common_encryption_request_);
342  DCHECK(request);
343  *request = *common_encryption_request_;
344 
345  request->add_tracks()->set_type("SD");
346  request->add_tracks()->set_type("HD");
347  request->add_tracks()->set_type("UHD1");
348  request->add_tracks()->set_type("UHD2");
349  request->add_tracks()->set_type("AUDIO");
350 
351  request->add_drm_types(ModularDrmType::WIDEVINE);
352 
353  if (enable_key_rotation) {
354  request->set_first_crypto_period_index(first_crypto_period_index);
355  request->set_crypto_period_count(crypto_period_count_);
356  }
357 
358  if (!group_id_.empty())
359  request->set_group_id(group_id_.data(), group_id_.size());
360 }
361 
362 Status WidevineKeySource::GenerateKeyMessage(
363  const CommonEncryptionRequest& request,
364  std::string* message) {
365  DCHECK(message);
366 
367  SignedModularDrmRequest signed_request;
368  signed_request.set_request(MessageToJsonString(request));
369 
370  // Sign the request.
371  if (signer_) {
372  std::string signature;
373  if (!signer_->GenerateSignature(signed_request.request(), &signature))
374  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
375 
376  signed_request.set_signature(signature);
377  signed_request.set_signer(signer_->signer_name());
378  }
379 
380  *message = MessageToJsonString(signed_request);
381  return Status::OK;
382 }
383 
384 bool WidevineKeySource::ExtractEncryptionKey(
385  bool enable_key_rotation,
386  bool widevine_classic,
387  const std::string& response,
388  bool* transient_error) {
389  DCHECK(transient_error);
390  *transient_error = false;
391 
392  SignedModularDrmResponse signed_response_proto;
393  if (!JsonStringToMessage(response, &signed_response_proto)) {
394  LOG(ERROR) << "Failed to convert JSON to proto: " << response;
395  return false;
396  }
397 
398  CommonEncryptionResponse response_proto;
399  if (!JsonStringToMessage(signed_response_proto.response(), &response_proto)) {
400  LOG(ERROR) << "Failed to convert JSON to proto: "
401  << signed_response_proto.response();
402  return false;
403  }
404 
405  if (response_proto.status() != CommonEncryptionResponse::OK) {
406  LOG(ERROR) << "Received non-OK license response: " << response;
407  // Server may return INTERNAL_ERROR intermittently, which is a transient
408  // error and the next client request may succeed without problem.
409  *transient_error =
410  (response_proto.status() == CommonEncryptionResponse::INTERNAL_ERROR);
411  return false;
412  }
413 
414  RCHECK(enable_key_rotation
415  ? response_proto.tracks_size() >= crypto_period_count_
416  : response_proto.tracks_size() >= 1);
417 
418  uint32_t current_crypto_period_index = first_crypto_period_index_;
419 
420  EncryptionKeyMap encryption_key_map;
421  for (const auto& track : response_proto.tracks()) {
422  VLOG(2) << "track " << track.ShortDebugString();
423 
424  if (enable_key_rotation) {
425  if (track.crypto_period_index() != current_crypto_period_index) {
426  if (track.crypto_period_index() != current_crypto_period_index + 1) {
427  LOG(ERROR) << "Expecting crypto period index "
428  << current_crypto_period_index << " or "
429  << current_crypto_period_index + 1 << "; Seen "
430  << track.crypto_period_index();
431  return false;
432  }
433  if (!PushToKeyPool(&encryption_key_map))
434  return false;
435  ++current_crypto_period_index;
436  }
437  }
438 
439  const std::string& stream_label = track.type();
440  RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
441 
442  std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
443  encryption_key->key.assign(track.key().begin(), track.key().end());
444 
445  // Get key ID and PSSH data for CENC content only.
446  if (!widevine_classic) {
447  encryption_key->key_id.assign(track.key_id().begin(),
448  track.key_id().end());
449 
450  if (generate_widevine_protection_system_) {
451  if (track.pssh_size() != 1) {
452  LOG(ERROR) << "Expecting one and only one pssh, seeing "
453  << track.pssh_size();
454  return false;
455  }
456  encryption_key->key_system_info.push_back(
457  ProtectionSystemInfoFromPsshProto(track.pssh(0)));
458  }
459  }
460  encryption_key_map[stream_label] = std::move(encryption_key);
461  }
462 
463  if (!widevine_classic) {
464  if (!UpdateProtectionSystemInfo(&encryption_key_map).ok()) {
465  return false;
466  }
467  }
468 
469  DCHECK(!encryption_key_map.empty());
470  if (!enable_key_rotation) {
471  // Merge with previously requested keys.
472  for (auto& pair : encryption_key_map)
473  encryption_key_map_[pair.first] = std::move(pair.second);
474  return true;
475  }
476  return PushToKeyPool(&encryption_key_map);
477 }
478 
479 bool WidevineKeySource::PushToKeyPool(
480  EncryptionKeyMap* encryption_key_map) {
481  DCHECK(key_pool_);
482  DCHECK(encryption_key_map);
483  auto encryption_key_map_shared = std::make_shared<EncryptionKeyMap>();
484  encryption_key_map_shared->swap(*encryption_key_map);
485  Status status = key_pool_->Push(encryption_key_map_shared, kInfiniteTimeout);
486  if (!status.ok()) {
487  DCHECK_EQ(error::STOPPED, status.error_code());
488  return false;
489  }
490  return true;
491 }
492 
493 } // namespace media
494 } // namespace shaka
All the methods that are virtual are virtual for mocking.
Status UpdateProtectionSystemInfo(EncryptionKeyMap *encryption_key_map)
Definition: key_source.cc:47
Status GetKey(const std::string &stream_label, EncryptionKey *key) override
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
WidevineKeySource(const std::string &server_url, int protection_systems_flags, FourCC protection_scheme)
static std::unique_ptr< PsshBoxBuilder > ParseFromBox(const uint8_t *data, size_t data_size)
Status GetCryptoPeriodKey(uint32_t crypto_period_index, const std::string &stream_label, EncryptionKey *key) override
void set_signer(std::unique_ptr< RequestSigner > signer)
KeySource is responsible for encryption key acquisition.
Definition: key_source.h:48
Status FetchKeys(EmeInitDataType init_data_type, const std::vector< uint8_t > &init_data) override
static bool ParseBoxes(const uint8_t *data, size_t data_size, std::vector< ProtectionSystemSpecificInfo > *pssh_boxes)