Shaka Packager SDK
widevine_key_source.cc
1 // Copyright 2014 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/widevine_key_source.h"
8 
9 #include "packager/base/base64.h"
10 #include "packager/base/bind.h"
11 #include "packager/base/strings/string_number_conversions.h"
12 #include "packager/media/base/http_key_fetcher.h"
13 #include "packager/media/base/network_util.h"
14 #include "packager/media/base/producer_consumer_queue.h"
15 #include "packager/media/base/protection_system_specific_info.h"
16 #include "packager/media/base/proto_json_util.h"
17 #include "packager/media/base/pssh_generator_util.h"
18 #include "packager/media/base/rcheck.h"
19 #include "packager/media/base/request_signer.h"
20 #include "packager/media/base/widevine_common_encryption.pb.h"
21 #include "packager/media/base/widevine_pssh_generator.h"
22 
23 namespace shaka {
24 namespace media {
25 namespace {
26 
// Passed as the |enable_key_rotation| argument of FetchKeysInternal();
// one-shot (non-rotating) fetches pass its negation.
constexpr bool kEnableKeyRotation = true;

// Client-side retry budget used when the license server reports a transient
// error (or the HTTP fetch times out). The delay before the first retry is
// kFirstRetryDelayMilliseconds and doubles after every retry.
constexpr int kNumTransientErrorRetries = 5;
constexpr int kFirstRetryDelayMilliseconds = 1000;

// Number of crypto periods (keys) requested per key-rotation request.
constexpr int kDefaultCryptoPeriodCount = 10;
// How long a consumer waits on the key pool for a crypto period's keys.
constexpr int kGetKeyTimeoutInSeconds = 5 * 60;  // 5 minutes.
// Timeout handed to the HTTP key fetcher for a single request.
constexpr int kKeyFetchTimeoutInSeconds = 60;  // 1 minute.
39 
40 CommonEncryptionRequest::ProtectionScheme ToCommonEncryptionProtectionScheme(
41  FourCC protection_scheme) {
42  switch (protection_scheme) {
43  case FOURCC_cenc:
44  return CommonEncryptionRequest::CENC;
45  case FOURCC_cbcs:
46  case kAppleSampleAesProtectionScheme:
47  // Treat sample aes as a variant of cbcs.
48  return CommonEncryptionRequest::CBCS;
49  case FOURCC_cbc1:
50  return CommonEncryptionRequest::CBC1;
51  case FOURCC_cens:
52  return CommonEncryptionRequest::CENS;
53  default:
54  LOG(WARNING) << "Ignore unrecognized protection scheme "
55  << FourCCToString(protection_scheme);
56  return CommonEncryptionRequest::UNSPECIFIED;
57  }
58 }
59 
60 ProtectionSystemSpecificInfo ProtectionSystemInfoFromPsshProto(
61  const CommonEncryptionResponse::Track::Pssh& pssh_proto) {
62  PsshBoxBuilder pssh_builder;
63  pssh_builder.set_system_id(kWidevineSystemId, arraysize(kWidevineSystemId));
64 
65  if (pssh_proto.has_boxes()) {
66  return {pssh_builder.system_id(),
67  std::vector<uint8_t>(pssh_proto.boxes().begin(),
68  pssh_proto.boxes().end())};
69  } else {
70  pssh_builder.set_pssh_box_version(0);
71  const std::vector<uint8_t> pssh_data(pssh_proto.data().begin(),
72  pssh_proto.data().end());
73  pssh_builder.set_pssh_data(pssh_data);
74  return {pssh_builder.system_id(), pssh_builder.CreateBox()};
75  }
76 }
77 
78 } // namespace
79 
80 WidevineKeySource::WidevineKeySource(const std::string& server_url,
81  int protection_system_flags)
82  // Widevine PSSH is fetched from Widevine license server.
83  : KeySource(protection_system_flags & ~WIDEVINE_PROTECTION_SYSTEM_FLAG),
84  key_production_thread_("KeyProductionThread",
85  base::Bind(&WidevineKeySource::FetchKeysTask,
86  base::Unretained(this))),
87  key_fetcher_(new HttpKeyFetcher(kKeyFetchTimeoutInSeconds)),
88  server_url_(server_url),
89  crypto_period_count_(kDefaultCryptoPeriodCount),
90  protection_scheme_(FOURCC_cenc),
91  key_production_started_(false),
92  start_key_production_(base::WaitableEvent::ResetPolicy::AUTOMATIC,
93  base::WaitableEvent::InitialState::NOT_SIGNALED),
94  first_crypto_period_index_(0),
95  enable_entitlement_license_(false) {
96  key_production_thread_.Start();
97 }
98 
99 WidevineKeySource::~WidevineKeySource() {
100  if (key_pool_)
101  key_pool_->Stop();
102  if (key_production_thread_.HasBeenStarted()) {
103  // Signal the production thread to start key production if it is not
104  // signaled yet so the thread can be joined.
105  start_key_production_.Signal();
106  key_production_thread_.Join();
107  }
108 }
109 
110 Status WidevineKeySource::FetchKeys(const std::vector<uint8_t>& content_id,
111  const std::string& policy) {
112  base::AutoLock scoped_lock(lock_);
113  common_encryption_request_.reset(new CommonEncryptionRequest);
114  common_encryption_request_->set_content_id(content_id.data(),
115  content_id.size());
116  common_encryption_request_->set_policy(policy);
117  common_encryption_request_->set_protection_scheme(
118  ToCommonEncryptionProtectionScheme(protection_scheme_));
119  if (enable_entitlement_license_)
120  common_encryption_request_->set_enable_entitlement_license(true);
121 
122  return FetchKeysInternal(!kEnableKeyRotation, 0, false);
123 }
124 
125 Status WidevineKeySource::FetchKeys(EmeInitDataType init_data_type,
126  const std::vector<uint8_t>& init_data) {
127  std::vector<uint8_t> pssh_data;
128  uint32_t asset_id = 0;
129  switch (init_data_type) {
130  case EmeInitDataType::CENC: {
131  const std::vector<uint8_t> widevine_system_id(
132  kWidevineSystemId, kWidevineSystemId + arraysize(kWidevineSystemId));
133  std::vector<ProtectionSystemSpecificInfo> protection_systems_info;
135  init_data.data(), init_data.size(), &protection_systems_info)) {
136  return Status(error::PARSER_FAILURE, "Error parsing the PSSH boxes.");
137  }
138  for (const auto& info : protection_systems_info) {
139  std::unique_ptr<PsshBoxBuilder> pssh_builder =
140  PsshBoxBuilder::ParseFromBox(info.psshs.data(), info.psshs.size());
141  if (!pssh_builder)
142  return Status(error::PARSER_FAILURE, "Error parsing the PSSH box.");
143  // Use Widevine PSSH if available otherwise construct a Widevine PSSH
144  // from the first available key ids.
145  if (info.system_id == widevine_system_id) {
146  pssh_data = pssh_builder->pssh_data();
147  break;
148  } else if (pssh_data.empty() && !pssh_builder->key_ids().empty()) {
149  pssh_data =
150  GenerateWidevinePsshDataFromKeyIds(pssh_builder->key_ids());
151  // Continue to see if there is any Widevine PSSH. The KeyId generated
152  // PSSH is only used if a Widevine PSSH could not be found.
153  continue;
154  }
155  }
156  if (pssh_data.empty())
157  return Status(error::INVALID_ARGUMENT, "No supported PSSHs found.");
158  break;
159  }
160  case EmeInitDataType::WEBM: {
161  pssh_data = GenerateWidevinePsshDataFromKeyIds({init_data});
162  break;
163  }
164  case EmeInitDataType::WIDEVINE_CLASSIC:
165  if (init_data.size() < sizeof(asset_id))
166  return Status(error::INVALID_ARGUMENT, "Invalid asset id.");
167  asset_id = ntohlFromBuffer(init_data.data());
168  break;
169  default:
170  LOG(ERROR) << "Init data type " << static_cast<int>(init_data_type)
171  << " not supported.";
172  return Status(error::INVALID_ARGUMENT, "Unsupported init data type.");
173  }
174  const bool widevine_classic =
175  init_data_type == EmeInitDataType::WIDEVINE_CLASSIC;
176  base::AutoLock scoped_lock(lock_);
177  common_encryption_request_.reset(new CommonEncryptionRequest);
178  if (widevine_classic) {
179  common_encryption_request_->set_asset_id(asset_id);
180  } else {
181  common_encryption_request_->set_pssh_data(pssh_data.data(),
182  pssh_data.size());
183  }
184  return FetchKeysInternal(!kEnableKeyRotation, 0, widevine_classic);
185 }
186 
187 Status WidevineKeySource::GetKey(const std::string& stream_label,
188  EncryptionKey* key) {
189  DCHECK(key);
190  if (encryption_key_map_.find(stream_label) == encryption_key_map_.end()) {
191  return Status(error::INTERNAL_ERROR,
192  "Cannot find key for '" + stream_label + "'.");
193  }
194  *key = *encryption_key_map_[stream_label];
195  return Status::OK;
196 }
197 
198 Status WidevineKeySource::GetKey(const std::vector<uint8_t>& key_id,
199  EncryptionKey* key) {
200  DCHECK(key);
201  for (const auto& pair : encryption_key_map_) {
202  if (pair.second->key_id == key_id) {
203  *key = *pair.second;
204  return Status::OK;
205  }
206  }
207  return Status(error::INTERNAL_ERROR,
208  "Cannot find key with specified key ID");
209 }
210 
211 Status WidevineKeySource::GetCryptoPeriodKey(uint32_t crypto_period_index,
212  const std::string& stream_label,
213  EncryptionKey* key) {
214  DCHECK(key_production_thread_.HasBeenStarted());
215  // TODO(kqyang): This is not elegant. Consider refactoring later.
216  {
217  base::AutoLock scoped_lock(lock_);
218  if (!key_production_started_) {
219  // Another client may have a slightly smaller starting crypto period
220  // index. Set the initial value to account for that.
221  first_crypto_period_index_ =
222  crypto_period_index ? crypto_period_index - 1 : 0;
223  DCHECK(!key_pool_);
224  key_pool_.reset(new EncryptionKeyQueue(crypto_period_count_,
225  first_crypto_period_index_));
226  start_key_production_.Signal();
227  key_production_started_ = true;
228  }
229  }
230  return GetKeyInternal(crypto_period_index, stream_label, key);
231 }
232 
233 void WidevineKeySource::set_signer(std::unique_ptr<RequestSigner> signer) {
234  signer_ = std::move(signer);
235 }
236 
238  std::unique_ptr<KeyFetcher> key_fetcher) {
239  key_fetcher_ = std::move(key_fetcher);
240 }
241 
242 Status WidevineKeySource::GetKeyInternal(uint32_t crypto_period_index,
243  const std::string& stream_label,
244  EncryptionKey* key) {
245  DCHECK(key_pool_);
246  DCHECK(key);
247 
248  std::shared_ptr<EncryptionKeyMap> encryption_key_map;
249  Status status = key_pool_->Peek(crypto_period_index, &encryption_key_map,
250  kGetKeyTimeoutInSeconds * 1000);
251  if (!status.ok()) {
252  if (status.error_code() == error::STOPPED) {
253  CHECK(!common_encryption_request_status_.ok());
254  return common_encryption_request_status_;
255  }
256  return status;
257  }
258 
259  if (encryption_key_map->find(stream_label) == encryption_key_map->end()) {
260  return Status(error::INTERNAL_ERROR,
261  "Cannot find key for '" + stream_label + "'.");
262  }
263  *key = *encryption_key_map->at(stream_label);
264  return Status::OK;
265 }
266 
267 void WidevineKeySource::FetchKeysTask() {
268  // Wait until key production is signaled.
269  start_key_production_.Wait();
270  if (!key_pool_ || key_pool_->Stopped())
271  return;
272 
273  Status status = FetchKeysInternal(kEnableKeyRotation,
274  first_crypto_period_index_,
275  false);
276  while (status.ok()) {
277  first_crypto_period_index_ += crypto_period_count_;
278  status = FetchKeysInternal(kEnableKeyRotation,
279  first_crypto_period_index_,
280  false);
281  }
282  common_encryption_request_status_ = status;
283  key_pool_->Stop();
284 }
285 
286 Status WidevineKeySource::FetchKeysInternal(bool enable_key_rotation,
287  uint32_t first_crypto_period_index,
288  bool widevine_classic) {
289  CommonEncryptionRequest request;
290  FillRequest(enable_key_rotation, first_crypto_period_index, &request);
291 
292  std::string message;
293  Status status = GenerateKeyMessage(request, &message);
294  if (!status.ok())
295  return status;
296  VLOG(1) << "Message: " << message;
297 
298  std::string raw_response;
299  int64_t sleep_duration = kFirstRetryDelayMilliseconds;
300 
301  // Perform client side retries if seeing server transient error to workaround
302  // server limitation.
303  for (int i = 0; i < kNumTransientErrorRetries; ++i) {
304  status = key_fetcher_->FetchKeys(server_url_, message, &raw_response);
305  if (status.ok()) {
306  VLOG(1) << "Retry [" << i << "] Response:" << raw_response;
307 
308  bool transient_error = false;
309  if (ExtractEncryptionKey(enable_key_rotation, widevine_classic,
310  raw_response, &transient_error))
311  return Status::OK;
312 
313  if (!transient_error) {
314  return Status(
315  error::SERVER_ERROR,
316  "Failed to extract encryption key from '" + raw_response + "'.");
317  }
318  } else if (status.error_code() != error::TIME_OUT) {
319  return status;
320  }
321 
322  // Exponential backoff.
323  if (i != kNumTransientErrorRetries - 1) {
324  base::PlatformThread::Sleep(
325  base::TimeDelta::FromMilliseconds(sleep_duration));
326  sleep_duration *= 2;
327  }
328  }
329  return Status(error::SERVER_ERROR,
330  "Failed to recover from server internal error.");
331 }
332 
333 void WidevineKeySource::FillRequest(bool enable_key_rotation,
334  uint32_t first_crypto_period_index,
335  CommonEncryptionRequest* request) {
336  DCHECK(common_encryption_request_);
337  DCHECK(request);
338  *request = *common_encryption_request_;
339 
340  request->add_tracks()->set_type("SD");
341  request->add_tracks()->set_type("HD");
342  request->add_tracks()->set_type("UHD1");
343  request->add_tracks()->set_type("UHD2");
344  request->add_tracks()->set_type("AUDIO");
345 
346  request->add_drm_types(ModularDrmType::WIDEVINE);
347 
348  if (enable_key_rotation) {
349  request->set_first_crypto_period_index(first_crypto_period_index);
350  request->set_crypto_period_count(crypto_period_count_);
351  }
352 
353  if (!group_id_.empty())
354  request->set_group_id(group_id_.data(), group_id_.size());
355 }
356 
357 Status WidevineKeySource::GenerateKeyMessage(
358  const CommonEncryptionRequest& request,
359  std::string* message) {
360  DCHECK(message);
361 
362  SignedModularDrmRequest signed_request;
363  signed_request.set_request(MessageToJsonString(request));
364 
365  // Sign the request.
366  if (signer_) {
367  std::string signature;
368  if (!signer_->GenerateSignature(signed_request.request(), &signature))
369  return Status(error::INTERNAL_ERROR, "Signature generation failed.");
370 
371  signed_request.set_signature(signature);
372  signed_request.set_signer(signer_->signer_name());
373  }
374 
375  *message = MessageToJsonString(signed_request);
376  return Status::OK;
377 }
378 
379 bool WidevineKeySource::ExtractEncryptionKey(
380  bool enable_key_rotation,
381  bool widevine_classic,
382  const std::string& response,
383  bool* transient_error) {
384  DCHECK(transient_error);
385  *transient_error = false;
386 
387  SignedModularDrmResponse signed_response_proto;
388  if (!JsonStringToMessage(response, &signed_response_proto)) {
389  LOG(ERROR) << "Failed to convert JSON to proto: " << response;
390  return false;
391  }
392 
393  CommonEncryptionResponse response_proto;
394  if (!JsonStringToMessage(signed_response_proto.response(), &response_proto)) {
395  LOG(ERROR) << "Failed to convert JSON to proto: "
396  << signed_response_proto.response();
397  return false;
398  }
399 
400  if (response_proto.status() != CommonEncryptionResponse::OK) {
401  LOG(ERROR) << "Received non-OK license response: " << response;
402  // Server may return INTERNAL_ERROR intermittently, which is a transient
403  // error and the next client request may succeed without problem.
404  *transient_error =
405  (response_proto.status() == CommonEncryptionResponse::INTERNAL_ERROR);
406  return false;
407  }
408 
409  RCHECK(enable_key_rotation
410  ? response_proto.tracks_size() >= crypto_period_count_
411  : response_proto.tracks_size() >= 1);
412 
413  uint32_t current_crypto_period_index = first_crypto_period_index_;
414 
415  EncryptionKeyMap encryption_key_map;
416  for (const auto& track : response_proto.tracks()) {
417  VLOG(2) << "track " << track.ShortDebugString();
418 
419  if (enable_key_rotation) {
420  if (track.crypto_period_index() != current_crypto_period_index) {
421  if (track.crypto_period_index() != current_crypto_period_index + 1) {
422  LOG(ERROR) << "Expecting crypto period index "
423  << current_crypto_period_index << " or "
424  << current_crypto_period_index + 1 << "; Seen "
425  << track.crypto_period_index();
426  return false;
427  }
428  if (!PushToKeyPool(&encryption_key_map))
429  return false;
430  ++current_crypto_period_index;
431  }
432  }
433 
434  const std::string& stream_label = track.type();
435  RCHECK(encryption_key_map.find(stream_label) == encryption_key_map.end());
436 
437  std::unique_ptr<EncryptionKey> encryption_key(new EncryptionKey());
438  encryption_key->key.assign(track.key().begin(), track.key().end());
439 
440  // Get key ID and PSSH data for CENC content only.
441  if (!widevine_classic) {
442  encryption_key->key_id.assign(track.key_id().begin(),
443  track.key_id().end());
444 
445  if (track.pssh_size() != 1) {
446  LOG(ERROR) << "Expecting one and only one pssh, seeing "
447  << track.pssh_size();
448  return false;
449  }
450  encryption_key->key_system_info.push_back(
451  ProtectionSystemInfoFromPsshProto(track.pssh(0)));
452  }
453  encryption_key_map[stream_label] = std::move(encryption_key);
454  }
455 
456  if (!widevine_classic) {
457  if (!UpdateProtectionSystemInfo(&encryption_key_map).ok()) {
458  return false;
459  }
460  }
461 
462  DCHECK(!encryption_key_map.empty());
463  if (!enable_key_rotation) {
464  // Merge with previously requested keys.
465  for (auto& pair : encryption_key_map)
466  encryption_key_map_[pair.first] = std::move(pair.second);
467  return true;
468  }
469  return PushToKeyPool(&encryption_key_map);
470 }
471 
472 bool WidevineKeySource::PushToKeyPool(
473  EncryptionKeyMap* encryption_key_map) {
474  DCHECK(key_pool_);
475  DCHECK(encryption_key_map);
476  auto encryption_key_map_shared = std::make_shared<EncryptionKeyMap>();
477  encryption_key_map_shared->swap(*encryption_key_map);
478  Status status = key_pool_->Push(encryption_key_map_shared, kInfiniteTimeout);
479  if (!status.ok()) {
480  DCHECK_EQ(error::STOPPED, status.error_code());
481  return false;
482  }
483  return true;
484 }
485 
486 } // namespace media
487 } // namespace shaka
All the methods that are virtual are virtual for mocking.
Status UpdateProtectionSystemInfo(EncryptionKeyMap *encryption_key_map)
Definition: key_source.cc:37
WidevineKeySource(const std::string &server_url, int protection_systems_flags)
Status GetKey(const std::string &stream_label, EncryptionKey *key) override
void set_key_fetcher(std::unique_ptr< KeyFetcher > key_fetcher)
static std::unique_ptr< PsshBoxBuilder > ParseFromBox(const uint8_t *data, size_t data_size)
Status GetCryptoPeriodKey(uint32_t crypto_period_index, const std::string &stream_label, EncryptionKey *key) override
void set_signer(std::unique_ptr< RequestSigner > signer)
KeySource is responsible for encryption key acquisition.
Definition: key_source.h:50
Status FetchKeys(EmeInitDataType init_data_type, const std::vector< uint8_t > &init_data) override
static bool ParseBoxes(const uint8_t *data, size_t data_size, std::vector< ProtectionSystemSpecificInfo > *pssh_boxes)