Shaka Packager SDK
encryption_handler.cc
1 // Copyright 2017 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/crypto/encryption_handler.h"
8 
9 #include <stddef.h>
10 #include <stdint.h>
11 
12 #include <algorithm>
13 
14 #include "packager/media/base/aes_encryptor.h"
15 #include "packager/media/base/audio_stream_info.h"
16 #include "packager/media/base/common_pssh_generator.h"
17 #include "packager/media/base/key_source.h"
18 #include "packager/media/base/macros.h"
19 #include "packager/media/base/media_sample.h"
20 #include "packager/media/base/playready_pssh_generator.h"
21 #include "packager/media/base/protection_system_ids.h"
22 #include "packager/media/base/video_stream_info.h"
23 #include "packager/media/base/widevine_pssh_generator.h"
24 #include "packager/media/crypto/aes_encryptor_factory.h"
25 #include "packager/media/crypto/subsample_generator.h"
26 #include "packager/status_macros.h"
27 
28 namespace shaka {
29 namespace media {
30 
31 namespace {
// The encryption handler only supports a single output, so everything is
// dispatched on this one stream index.
const size_t kStreamIndex = 0;

// The default KID, KEY and IV for key rotation are all 0s.
// They are placeholders and are not really being used to encrypt data.
// Sizes: 16-byte key id, 16-byte key, 8-byte IV.
const uint8_t kKeyRotationDefaultKeyId[] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
const uint8_t kKeyRotationDefaultKey[] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
};
const uint8_t kKeyRotationDefaultIv[] = {
    0, 0, 0, 0, 0, 0, 0, 0,
};
46 
47 std::string GetStreamLabelForEncryption(
48  const StreamInfo& stream_info,
49  const std::function<std::string(
50  const EncryptionParams::EncryptedStreamAttributes& stream_attributes)>&
51  stream_label_func) {
52  EncryptionParams::EncryptedStreamAttributes stream_attributes;
53  if (stream_info.stream_type() == kStreamAudio) {
54  stream_attributes.stream_type =
55  EncryptionParams::EncryptedStreamAttributes::kAudio;
56  } else if (stream_info.stream_type() == kStreamVideo) {
57  const VideoStreamInfo& video_stream_info =
58  static_cast<const VideoStreamInfo&>(stream_info);
59  stream_attributes.stream_type =
60  EncryptionParams::EncryptedStreamAttributes::kVideo;
61  stream_attributes.oneof.video.width = video_stream_info.width();
62  stream_attributes.oneof.video.height = video_stream_info.height();
63  }
64  return stream_label_func(stream_attributes);
65 }
66 
67 bool IsPatternEncryptionScheme(FourCC protection_scheme) {
68  return protection_scheme == kAppleSampleAesProtectionScheme ||
69  protection_scheme == FOURCC_cbcs || protection_scheme == FOURCC_cens;
70 }
71 
72 void FillPsshGenerators(
73  const EncryptionParams& encryption_params,
74  std::vector<std::unique_ptr<PsshGenerator>>* pssh_generators,
75  std::vector<std::vector<uint8_t>>* no_pssh_systems) {
76  if (has_flag(encryption_params.protection_systems,
78  pssh_generators->emplace_back(new CommonPsshGenerator());
79  }
80 
81  if (has_flag(encryption_params.protection_systems,
82  ProtectionSystem::kPlayReady)) {
83  pssh_generators->emplace_back(new PlayReadyPsshGenerator(
84  encryption_params.playready_extra_header_data,
85  static_cast<FourCC>(encryption_params.protection_scheme)));
86  }
87 
88  if (has_flag(encryption_params.protection_systems,
89  ProtectionSystem::kWidevine)) {
90  pssh_generators->emplace_back(new WidevinePsshGenerator(
91  static_cast<FourCC>(encryption_params.protection_scheme)));
92  }
93 
94  if (has_flag(encryption_params.protection_systems,
95  ProtectionSystem::kFairPlay)) {
96  no_pssh_systems->emplace_back(std::begin(kFairPlaySystemId),
97  std::end(kFairPlaySystemId));
98  }
99  // We only support Marlin Adaptive Streaming Specification – Simple Profile
100  // with Implicit Content ID Mapping, which does not need a PSSH. Marlin
101  // specific PSSH with Explicit Content ID Mapping is not generated.
102  if (has_flag(encryption_params.protection_systems,
103  ProtectionSystem::kMarlin)) {
104  no_pssh_systems->emplace_back(std::begin(kMarlinSystemId),
105  std::end(kMarlinSystemId));
106  }
107 
108  if (pssh_generators->empty() && no_pssh_systems->empty() &&
109  (encryption_params.key_provider != KeyProvider::kRawKey ||
110  encryption_params.raw_key.pssh.empty())) {
111  pssh_generators->emplace_back(new CommonPsshGenerator());
112  }
113 }
114 
115 void AddProtectionSystemIfNotExist(
116  const ProtectionSystemSpecificInfo& pssh_info,
117  EncryptionConfig* encryption_config) {
118  for (const auto& info : encryption_config->key_system_info) {
119  if (info.system_id == pssh_info.system_id)
120  return;
121  }
122  encryption_config->key_system_info.push_back(pssh_info);
123 }
124 
125 Status FillProtectionSystemInfo(const EncryptionParams& encryption_params,
126  const EncryptionKey& encryption_key,
127  EncryptionConfig* encryption_config) {
128  // If generating dummy keys for key rotation, don't generate PSSH info.
129  if (encryption_key.key_ids.empty())
130  return Status::OK;
131 
132  std::vector<std::unique_ptr<PsshGenerator>> pssh_generators;
133  std::vector<std::vector<uint8_t>> no_pssh_systems;
134  FillPsshGenerators(encryption_params, &pssh_generators, &no_pssh_systems);
135 
136  encryption_config->key_system_info = encryption_key.key_system_info;
137  for (const auto& pssh_generator : pssh_generators) {
138  const bool support_multiple_keys = pssh_generator->SupportMultipleKeys();
139  if (support_multiple_keys) {
140  ProtectionSystemSpecificInfo info;
141  RETURN_IF_ERROR(pssh_generator->GeneratePsshFromKeyIds(
142  encryption_key.key_ids, &info));
143  AddProtectionSystemIfNotExist(info, encryption_config);
144  } else {
145  ProtectionSystemSpecificInfo info;
146  RETURN_IF_ERROR(pssh_generator->GeneratePsshFromKeyIdAndKey(
147  encryption_key.key_id, encryption_key.key, &info));
148  AddProtectionSystemIfNotExist(info, encryption_config);
149  }
150  }
151 
152  for (const auto& no_pssh_system : no_pssh_systems) {
153  ProtectionSystemSpecificInfo info;
154  info.system_id = no_pssh_system;
155  AddProtectionSystemIfNotExist(info, encryption_config);
156  }
157 
158  return Status::OK;
159 }
160 
161 } // namespace
162 
163 EncryptionHandler::EncryptionHandler(const EncryptionParams& encryption_params,
164  KeySource* key_source)
165  : encryption_params_(encryption_params),
166  protection_scheme_(
167  static_cast<FourCC>(encryption_params.protection_scheme)),
168  key_source_(key_source),
169  subsample_generator_(
170  new SubsampleGenerator(encryption_params.vp9_subsample_encryption)),
171  encryptor_factory_(new AesEncryptorFactory) {}
172 
173 EncryptionHandler::~EncryptionHandler() = default;
174 
176  if (!encryption_params_.stream_label_func) {
177  return Status(error::INVALID_ARGUMENT, "Stream label function not set.");
178  }
179  if (num_input_streams() != 1 || next_output_stream_index() != 1) {
180  return Status(error::INVALID_ARGUMENT,
181  "Expects exactly one input and output.");
182  }
183  return Status::OK;
184 }
185 
186 Status EncryptionHandler::Process(std::unique_ptr<StreamData> stream_data) {
187  switch (stream_data->stream_data_type) {
188  case StreamDataType::kStreamInfo:
189  return ProcessStreamInfo(*stream_data->stream_info);
190  case StreamDataType::kSegmentInfo: {
191  std::shared_ptr<SegmentInfo> segment_info(new SegmentInfo(
192  *stream_data->segment_info));
193 
194  segment_info->is_encrypted = remaining_clear_lead_ <= 0;
195 
196  const bool key_rotation_enabled = crypto_period_duration_ != 0;
197  if (key_rotation_enabled)
198  segment_info->key_rotation_encryption_config = encryption_config_;
199  if (!segment_info->is_subsegment) {
200  if (key_rotation_enabled)
201  check_new_crypto_period_ = true;
202  if (remaining_clear_lead_ > 0)
203  remaining_clear_lead_ -= segment_info->duration;
204  }
205 
206  return DispatchSegmentInfo(kStreamIndex, segment_info);
207  }
208  case StreamDataType::kMediaSample:
209  return ProcessMediaSample(std::move(stream_data->media_sample));
210  default:
211  VLOG(3) << "Stream data type "
212  << static_cast<int>(stream_data->stream_data_type) << " ignored.";
213  return Dispatch(std::move(stream_data));
214  }
215 }
216 
217 Status EncryptionHandler::ProcessStreamInfo(const StreamInfo& clear_info) {
218  if (clear_info.is_encrypted()) {
219  return Status(error::INVALID_ARGUMENT,
220  "Input stream is already encrypted.");
221  }
222 
223  DCHECK_NE(kStreamUnknown, clear_info.stream_type());
224  DCHECK_NE(kStreamText, clear_info.stream_type());
225  std::shared_ptr<StreamInfo> stream_info = clear_info.Clone();
226  RETURN_IF_ERROR(
227  subsample_generator_->Initialize(protection_scheme_, *stream_info));
228 
229  remaining_clear_lead_ =
230  encryption_params_.clear_lead_in_seconds * stream_info->time_scale();
231  crypto_period_duration_ =
232  encryption_params_.crypto_period_duration_in_seconds *
233  stream_info->time_scale();
234  codec_ = stream_info->codec();
235  stream_label_ = GetStreamLabelForEncryption(
236  *stream_info, encryption_params_.stream_label_func);
237 
238  SetupProtectionPattern(stream_info->stream_type());
239 
240  EncryptionKey encryption_key;
241  const bool key_rotation_enabled = crypto_period_duration_ != 0;
242  if (key_rotation_enabled) {
243  check_new_crypto_period_ = true;
244  // Setup dummy key id, key and iv to signal encryption for key rotation.
245  encryption_key.key_id.assign(std::begin(kKeyRotationDefaultKeyId),
246  std::end(kKeyRotationDefaultKeyId));
247  encryption_key.key.assign(std::begin(kKeyRotationDefaultKey),
248  std::end(kKeyRotationDefaultKey));
249  encryption_key.iv.assign(std::begin(kKeyRotationDefaultIv),
250  std::end(kKeyRotationDefaultIv));
251  } else {
252  RETURN_IF_ERROR(key_source_->GetKey(stream_label_, &encryption_key));
253  }
254  if (!CreateEncryptor(encryption_key))
255  return Status(error::ENCRYPTION_FAILURE, "Failed to create encryptor");
256 
257  stream_info->set_is_encrypted(true);
258  stream_info->set_has_clear_lead(encryption_params_.clear_lead_in_seconds > 0);
259  stream_info->set_encryption_config(*encryption_config_);
260 
261  return DispatchStreamInfo(kStreamIndex, stream_info);
262 }
263 
264 Status EncryptionHandler::ProcessMediaSample(
265  std::shared_ptr<const MediaSample> clear_sample) {
266  DCHECK(clear_sample);
267 
268  // Process the frame even if the frame is not encrypted as the next
269  // (encrypted) frame may be dependent on this clear frame.
270  std::vector<SubsampleEntry> subsamples;
271  RETURN_IF_ERROR(subsample_generator_->GenerateSubsamples(
272  clear_sample->data(), clear_sample->data_size(), &subsamples));
273 
274  // Need to setup the encryptor for new segments even if this segment does not
275  // need to be encrypted, so we can signal encryption metadata earlier to
276  // allows clients to prefetch the keys.
277  if (check_new_crypto_period_) {
278  // |dts| can be negative, e.g. after EditList adjustments. Normalized to 0
279  // in that case.
280  const int64_t dts = std::max(clear_sample->dts(), static_cast<int64_t>(0));
281  const int64_t current_crypto_period_index = dts / crypto_period_duration_;
282  const uint32_t crypto_period_duration_in_seconds =
283  static_cast<uint32_t>(encryption_params_.crypto_period_duration_in_seconds);
284  if (current_crypto_period_index != prev_crypto_period_index_) {
285  EncryptionKey encryption_key;
286  RETURN_IF_ERROR(key_source_->GetCryptoPeriodKey(
287  current_crypto_period_index, crypto_period_duration_in_seconds,
288  stream_label_, &encryption_key));
289  if (!CreateEncryptor(encryption_key))
290  return Status(error::ENCRYPTION_FAILURE, "Failed to create encryptor");
291  prev_crypto_period_index_ = current_crypto_period_index;
292  }
293  check_new_crypto_period_ = false;
294  }
295 
296  // Since there is no encryption needed right now, send the clear copy
297  // downstream so we can save the costs of copying it.
298  if (remaining_clear_lead_ > 0) {
299  return DispatchMediaSample(kStreamIndex, std::move(clear_sample));
300  }
301 
302  std::shared_ptr<uint8_t> cipher_sample_data(
303  new uint8_t[clear_sample->data_size()], std::default_delete<uint8_t[]>());
304 
305  const uint8_t* source = clear_sample->data();
306  uint8_t* dest = cipher_sample_data.get();
307  if (!subsamples.empty()) {
308  size_t total_size = 0;
309  for (const SubsampleEntry& subsample : subsamples) {
310  if (subsample.clear_bytes > 0) {
311  memcpy(dest, source, subsample.clear_bytes);
312  source += subsample.clear_bytes;
313  dest += subsample.clear_bytes;
314  total_size += subsample.clear_bytes;
315  }
316  if (subsample.cipher_bytes > 0) {
317  EncryptBytes(source, subsample.cipher_bytes, dest);
318  source += subsample.cipher_bytes;
319  dest += subsample.cipher_bytes;
320  total_size += subsample.cipher_bytes;
321  }
322  }
323  DCHECK_EQ(total_size, clear_sample->data_size());
324  } else {
325  EncryptBytes(source, clear_sample->data_size(), dest);
326  }
327 
328  std::shared_ptr<MediaSample> cipher_sample(clear_sample->Clone());
329  cipher_sample->TransferData(std::move(cipher_sample_data),
330  clear_sample->data_size());
331 
332  // Finish initializing the sample before sending it downstream. We must
333  // wait until now to finish the initialization as we will lose access to
334  // |decrypt_config| once we set it.
335  cipher_sample->set_is_encrypted(true);
336  std::unique_ptr<DecryptConfig> decrypt_config(new DecryptConfig(
337  encryption_config_->key_id, encryptor_->iv(), subsamples,
338  protection_scheme_, crypt_byte_block_, skip_byte_block_));
339  cipher_sample->set_decrypt_config(std::move(decrypt_config));
340 
341  encryptor_->UpdateIv();
342 
343  return DispatchMediaSample(kStreamIndex, std::move(cipher_sample));
344 }
345 
346 void EncryptionHandler::SetupProtectionPattern(StreamType stream_type) {
347  if (stream_type == kStreamVideo &&
348  IsPatternEncryptionScheme(protection_scheme_)) {
349  crypt_byte_block_ = encryption_params_.crypt_byte_block;
350  skip_byte_block_ = encryption_params_.skip_byte_block;
351  } else {
352  // Audio stream in pattern encryption scheme does not use pattern; it uses
353  // whole-block full sample encryption instead. Non-pattern encryption does
354  // not have pattern.
355  crypt_byte_block_ = 0u;
356  skip_byte_block_ = 0u;
357  }
358 }
359 
360 bool EncryptionHandler::CreateEncryptor(const EncryptionKey& encryption_key) {
361  std::unique_ptr<AesCryptor> encryptor = encryptor_factory_->CreateEncryptor(
362  protection_scheme_, crypt_byte_block_, skip_byte_block_, codec_,
363  encryption_key.key, encryption_key.iv);
364  if (!encryptor)
365  return false;
366  encryptor_ = std::move(encryptor);
367 
368  encryption_config_.reset(new EncryptionConfig);
369  encryption_config_->protection_scheme = protection_scheme_;
370  encryption_config_->crypt_byte_block = crypt_byte_block_;
371  encryption_config_->skip_byte_block = skip_byte_block_;
372 
373  const std::vector<uint8_t>& iv = encryptor_->iv();
374  if (encryptor_->use_constant_iv()) {
375  encryption_config_->per_sample_iv_size = 0;
376  encryption_config_->constant_iv = iv;
377  } else {
378  encryption_config_->per_sample_iv_size = static_cast<uint8_t>(iv.size());
379  }
380 
381  encryption_config_->key_id = encryption_key.key_id;
382  const auto status = FillProtectionSystemInfo(
383  encryption_params_, encryption_key, encryption_config_.get());
384  return status.ok();
385 }
386 
387 void EncryptionHandler::EncryptBytes(const uint8_t* source,
388  size_t source_size,
389  uint8_t* dest) {
390  DCHECK(source);
391  DCHECK(dest);
392  DCHECK(encryptor_);
393  CHECK(encryptor_->Crypt(source, source_size, dest));
394 }
395 
396 void EncryptionHandler::InjectSubsampleGeneratorForTesting(
397  std::unique_ptr<SubsampleGenerator> generator) {
398  subsample_generator_ = std::move(generator);
399 }
400 
401 void EncryptionHandler::InjectEncryptorFactoryForTesting(
402  std::unique_ptr<AesEncryptorFactory> encryptor_factory) {
403  encryptor_factory_ = std::move(encryptor_factory);
404 }
405 
406 } // namespace media
407 } // namespace shaka
shaka::ProtectionSystem::kCommon
kCommon
The common key system from EME: https://goo.gl/s8RIhr.
shaka::media::SegmentInfo
Definition: media_handler.h:55
shaka::media::MediaHandler::DispatchMediaSample
Status DispatchMediaSample(size_t stream_index, std::shared_ptr< const MediaSample > media_sample) const
Dispatch the media sample to downstream handlers.
Definition: media_handler.h:207
shaka::EncryptionParams::crypt_byte_block
uint8_t crypt_byte_block
Definition: crypto_params.h:174
shaka
All the methods that are virtual are virtual for mocking.
Definition: gflags_hex_bytes.cc:11
shaka::media::MediaHandler::DispatchStreamInfo
Status DispatchStreamInfo(size_t stream_index, std::shared_ptr< const StreamInfo > stream_info) const
Dispatch the stream info to downstream handlers.
Definition: media_handler.h:199
shaka::media::KeySource::GetKey
virtual Status GetKey(const std::string &stream_label, EncryptionKey *key)=0
shaka::Status
Definition: status.h:110
shaka::media::KeySource::GetCryptoPeriodKey
virtual Status GetCryptoPeriodKey(uint32_t crypto_period_index, uint32_t crypto_period_duration_in_seconds, const std::string &stream_label, EncryptionKey *key)=0
shaka::media::StreamInfo::Clone
virtual std::unique_ptr< StreamInfo > Clone() const =0
shaka::media::MediaHandler::Dispatch
Status Dispatch(std::unique_ptr< StreamData > stream_data) const
Definition: media_handler.cc:94
shaka::media::EncryptionHandler::InitializeInternal
Status InitializeInternal() override
Definition: encryption_handler.cc:175
shaka::EncryptionParams::clear_lead_in_seconds
double clear_lead_in_seconds
Clear lead duration in seconds.
Definition: crypto_params.h:162
shaka::media::EncryptionHandler::Process
Status Process(std::unique_ptr< StreamData > stream_data) override
Definition: encryption_handler.cc:186
shaka::EncryptionParams::stream_label_func
std::function< std::string(const EncryptedStreamAttributes &stream_attributes)> stream_label_func
Definition: crypto_params.h:216
shaka::EncryptionParams::skip_byte_block
uint8_t skip_byte_block
Definition: crypto_params.h:178
shaka::media::MediaHandler::DispatchSegmentInfo
Status DispatchSegmentInfo(size_t stream_index, std::shared_ptr< const SegmentInfo > segment_info) const
Dispatch the segment info to downstream handlers.
Definition: media_handler.h:224
shaka::media::StreamInfo
Abstract class holds stream information.
Definition: stream_info.h:65