DASH Media Packaging SDK
 All Classes Namespaces Functions Variables Typedefs Enumerations Enumerator
aes_decryptor.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/aes_decryptor.h"
8 
9 #include <openssl/aes.h>
10 #include <algorithm>
11 #include "packager/base/logging.h"
12 
namespace {

// Returns true iff |key_size| (in bytes) is one of the three AES key lengths:
// 16 bytes (AES-128), 24 bytes (AES-192) or 32 bytes (AES-256).
bool IsKeySizeValidForAes(size_t key_size) {
  switch (key_size) {
    case 16:
    case 24:
    case 32:
      return true;
    default:
      return false;
  }
}

}  // namespace
21 
22 namespace shaka {
23 namespace media {
24 
25 AesCbcDecryptor::AesCbcDecryptor(CbcPaddingScheme padding_scheme)
26  : AesCbcDecryptor(padding_scheme, kDontUseConstantIv) {}
27 
28 AesCbcDecryptor::AesCbcDecryptor(CbcPaddingScheme padding_scheme,
29  ConstantIvFlag constant_iv_flag)
30  : AesCryptor(constant_iv_flag), padding_scheme_(padding_scheme) {
31  if (padding_scheme_ != kNoPadding) {
32  CHECK_EQ(constant_iv_flag, kUseConstantIv)
33  << "non-constant iv (cipher block chain across calls) only makes sense "
34  "if the padding_scheme is kNoPadding.";
35  }
36 }
37 
38 AesCbcDecryptor::~AesCbcDecryptor() {}
39 
40 bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
41  const std::vector<uint8_t>& iv) {
42  if (!IsKeySizeValidForAes(key.size())) {
43  LOG(ERROR) << "Invalid AES key size: " << key.size();
44  return false;
45  }
46 
47  CHECK_EQ(AES_set_decrypt_key(key.data(), key.size() * 8, mutable_aes_key()),
48  0);
49  return SetIv(iv);
50 }
51 
52 bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
53  size_t ciphertext_size,
54  uint8_t* plaintext,
55  size_t* plaintext_size) {
56  DCHECK(plaintext_size);
57  DCHECK(aes_key());
58  // Plaintext size is the same as ciphertext size except for pkcs5 padding.
59  // Will update later if using pkcs5 padding. For pkcs5 padding, we still
60  // need at least |ciphertext_size| bytes for intermediate operation.
61  if (*plaintext_size < ciphertext_size) {
62  LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
63  << " bytes.";
64  return false;
65  }
66  *plaintext_size = ciphertext_size;
67 
68  if (ciphertext_size == 0) {
69  if (padding_scheme_ == kPkcs5Padding) {
70  LOG(ERROR) << "Expected ciphertext to be at least " << AES_BLOCK_SIZE
71  << " bytes with Pkcs5 padding.";
72  return false;
73  }
74  return true;
75  }
76  DCHECK(plaintext);
77 
78  const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
79  const size_t cbc_size = ciphertext_size - residual_block_size;
80  if (residual_block_size == 0) {
81  AES_cbc_encrypt(ciphertext, plaintext, ciphertext_size, aes_key(),
82  internal_iv_.data(), AES_DECRYPT);
83  if (padding_scheme_ != kPkcs5Padding)
84  return true;
85 
86  // Strip off PKCS5 padding bytes.
87  const uint8_t num_padding_bytes = plaintext[ciphertext_size - 1];
88  if (num_padding_bytes > AES_BLOCK_SIZE) {
89  LOG(ERROR) << "Padding length is too large : "
90  << static_cast<int>(num_padding_bytes);
91  return false;
92  }
93  *plaintext_size -= num_padding_bytes;
94  return true;
95  } else if (padding_scheme_ == kNoPadding) {
96  AES_cbc_encrypt(ciphertext, plaintext, cbc_size, aes_key(),
97  internal_iv_.data(), AES_DECRYPT);
98 
99  // The residual block is not encrypted.
100  memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
101  return true;
102  } else if (padding_scheme_ != kCtsPadding) {
103  LOG(ERROR) << "Expecting cipher text size to be multiple of "
104  << AES_BLOCK_SIZE << ", got " << ciphertext_size;
105  return false;
106  }
107 
108  DCHECK_EQ(padding_scheme_, kCtsPadding);
109  if (ciphertext_size < AES_BLOCK_SIZE) {
110  // Don't have a full block, leave unencrypted.
111  memcpy(plaintext, ciphertext, ciphertext_size);
112  return true;
113  }
114 
115  // AES-CBC decrypt everything up to the next-to-last full block.
116  if (cbc_size > AES_BLOCK_SIZE) {
117  AES_cbc_encrypt(ciphertext, plaintext, cbc_size - AES_BLOCK_SIZE, aes_key(),
118  internal_iv_.data(), AES_DECRYPT);
119  }
120 
121  const uint8_t* next_to_last_ciphertext_block =
122  ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
123  uint8_t* next_to_last_plaintext_block =
124  plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
125 
126  // Determine what the last IV should be so that we can "skip ahead" in the
127  // CBC decryption.
128  std::vector<uint8_t> last_iv(
129  ciphertext + ciphertext_size - residual_block_size,
130  ciphertext + ciphertext_size);
131  last_iv.resize(AES_BLOCK_SIZE, 0);
132 
133  // Decrypt the next-to-last block using the IV determined above. This decrypts
134  // the residual block bits.
135  AES_cbc_encrypt(next_to_last_ciphertext_block, next_to_last_plaintext_block,
136  AES_BLOCK_SIZE, aes_key(), last_iv.data(), AES_DECRYPT);
137 
138  // Swap back the residual block bits and the next-to-last block.
139  if (plaintext == ciphertext) {
140  std::swap_ranges(next_to_last_plaintext_block,
141  next_to_last_plaintext_block + residual_block_size,
142  next_to_last_plaintext_block + AES_BLOCK_SIZE);
143  } else {
144  memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
145  next_to_last_plaintext_block, residual_block_size);
146  memcpy(next_to_last_plaintext_block,
147  next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
148  }
149 
150  // Decrypt the next-to-last full block.
151  AES_cbc_encrypt(next_to_last_plaintext_block, next_to_last_plaintext_block,
152  AES_BLOCK_SIZE, aes_key(), internal_iv_.data(), AES_DECRYPT);
153  return true;
154 }
155 
156 void AesCbcDecryptor::SetIvInternal() {
157  internal_iv_ = iv();
158  internal_iv_.resize(AES_BLOCK_SIZE, 0);
159 }
160 
161 } // namespace media
162 } // namespace shaka
Class which implements AES-CBC (Cipher block chaining) decryption.
Definition: aes_decryptor.h:25
AesCbcDecryptor(CbcPaddingScheme padding_scheme)
const std::vector< uint8_t > & iv() const
Definition: aes_cryptor.h:81
bool InitializeWithIv(const std::vector< uint8_t > &key, const std::vector< uint8_t > &iv) override
bool SetIv(const std::vector< uint8_t > &iv)
Definition: aes_cryptor.cc:67