DASH Media Packaging SDK
 All Classes Namespaces Functions Variables Typedefs Enumerator
aes_decryptor.cc
1 // Copyright 2016 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/media/base/aes_decryptor.h"
8 
9 #include <openssl/aes.h>
10 
11 #include "packager/base/logging.h"
12 
13 namespace {
14 
// Reports whether |key_size| (in bytes) is one of the three key lengths AES
// defines: 16 bytes (AES-128), 24 bytes (AES-192) or 32 bytes (AES-256).
bool IsKeySizeValidForAes(size_t key_size) {
  switch (key_size) {
    case 16:
    case 24:
    case 32:
      return true;
    default:
      return false;
  }
}
19 
20 } // namespace
21 
22 namespace edash_packager {
23 namespace media {
24 
25 AesCbcDecryptor::AesCbcDecryptor(CbcPaddingScheme padding_scheme,
26  bool chain_across_calls)
27  : padding_scheme_(padding_scheme),
28  chain_across_calls_(chain_across_calls) {
29  if (padding_scheme_ != kNoPadding) {
30  CHECK(!chain_across_calls) << "cipher block chain across calls only makes "
31  "sense if the padding_scheme is kNoPadding.";
32  }
33 }
34 
35 AesCbcDecryptor::~AesCbcDecryptor() {}
36 
37 bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
38  const std::vector<uint8_t>& iv) {
39  if (!IsKeySizeValidForAes(key.size())) {
40  LOG(ERROR) << "Invalid AES key size: " << key.size();
41  return false;
42  }
43 
44  CHECK_EQ(AES_set_decrypt_key(key.data(), key.size() * 8, mutable_aes_key()),
45  0);
46  return SetIv(iv);
47 }
48 
49 bool AesCbcDecryptor::SetIv(const std::vector<uint8_t>& iv) {
50  if (iv.size() != AES_BLOCK_SIZE) {
51  LOG(ERROR) << "Invalid IV size: " << iv.size();
52  return false;
53  }
54 
55  set_iv(iv);
56  return true;
57 }
58 
59 bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
60  size_t ciphertext_size,
61  uint8_t* plaintext,
62  size_t* plaintext_size) {
63  DCHECK(plaintext_size);
64  DCHECK(aes_key());
65  // Plaintext size is the same as ciphertext size except for pkcs5 padding.
66  // Will update later if using pkcs5 padding. For pkcs5 padding, we still
67  // need at least |ciphertext_size| bytes for intermediate operation.
68  if (*plaintext_size < ciphertext_size) {
69  LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
70  << " bytes.";
71  return false;
72  }
73  *plaintext_size = ciphertext_size;
74 
75  if (ciphertext_size == 0) {
76  if (padding_scheme_ == kPkcs5Padding) {
77  LOG(ERROR) << "Expected ciphertext to be at least " << AES_BLOCK_SIZE
78  << " bytes with Pkcs5 padding.";
79  return false;
80  }
81  return true;
82  }
83  DCHECK(plaintext);
84 
85  std::vector<uint8_t> local_iv(iv());
86  const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
87  const size_t cbc_size = ciphertext_size - residual_block_size;
88  if (residual_block_size == 0) {
89  AES_cbc_encrypt(ciphertext, plaintext, ciphertext_size, aes_key(),
90  local_iv.data(), AES_DECRYPT);
91  if (chain_across_calls_)
92  set_iv(local_iv);
93  if (padding_scheme_ != kPkcs5Padding)
94  return true;
95 
96  // Strip off PKCS5 padding bytes.
97  const uint8_t num_padding_bytes = plaintext[ciphertext_size - 1];
98  if (num_padding_bytes > AES_BLOCK_SIZE) {
99  LOG(ERROR) << "Padding length is too large : "
100  << static_cast<int>(num_padding_bytes);
101  return false;
102  }
103  *plaintext_size -= num_padding_bytes;
104  return true;
105  } else if (padding_scheme_ == kNoPadding) {
106  AES_cbc_encrypt(ciphertext, plaintext, cbc_size, aes_key(), local_iv.data(),
107  AES_DECRYPT);
108  if (chain_across_calls_)
109  set_iv(local_iv);
110 
111  // The residual block is not encrypted.
112  memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
113  return true;
114  } else if (padding_scheme_ != kCtsPadding) {
115  LOG(ERROR) << "Expecting cipher text size to be multiple of "
116  << AES_BLOCK_SIZE << ", got " << ciphertext_size;
117  return false;
118  }
119 
120  DCHECK(!chain_across_calls_);
121  DCHECK_EQ(padding_scheme_, kCtsPadding);
122  if (ciphertext_size < AES_BLOCK_SIZE) {
123  // Don't have a full block, leave unencrypted.
124  memcpy(plaintext, ciphertext, ciphertext_size);
125  return true;
126  }
127 
128  // AES-CBC decrypt everything up to the next-to-last full block.
129  if (cbc_size > AES_BLOCK_SIZE) {
130  AES_cbc_encrypt(ciphertext, plaintext, cbc_size - AES_BLOCK_SIZE, aes_key(),
131  local_iv.data(), AES_DECRYPT);
132  }
133 
134  const uint8_t* next_to_last_ciphertext_block =
135  ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
136  uint8_t* next_to_last_plaintext_block =
137  plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
138 
139  // Determine what the last IV should be so that we can "skip ahead" in the
140  // CBC decryption.
141  std::vector<uint8_t> last_iv(
142  ciphertext + ciphertext_size - residual_block_size,
143  ciphertext + ciphertext_size);
144  last_iv.resize(AES_BLOCK_SIZE, 0);
145 
146  // Decrypt the next-to-last block using the IV determined above. This decrypts
147  // the residual block bits.
148  AES_cbc_encrypt(next_to_last_ciphertext_block, next_to_last_plaintext_block,
149  AES_BLOCK_SIZE, aes_key(), last_iv.data(), AES_DECRYPT);
150 
151  // Swap back the residual block bits and the next-to-last block.
152  if (plaintext == ciphertext) {
153  std::swap_ranges(next_to_last_plaintext_block,
154  next_to_last_plaintext_block + residual_block_size,
155  next_to_last_plaintext_block + AES_BLOCK_SIZE);
156  } else {
157  memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
158  next_to_last_plaintext_block, residual_block_size);
159  memcpy(next_to_last_plaintext_block,
160  next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
161  }
162 
163  // Decrypt the next-to-last full block.
164  AES_cbc_encrypt(next_to_last_plaintext_block, next_to_last_plaintext_block,
165  AES_BLOCK_SIZE, aes_key(), local_iv.data(), AES_DECRYPT);
166  return true;
167 }
168 
169 } // namespace media
170 } // namespace edash_packager
bool SetIv(const std::vector< uint8_t > &iv) override
bool InitializeWithIv(const std::vector< uint8_t > &key, const std::vector< uint8_t > &iv) override
AesCbcDecryptor(CbcPaddingScheme padding_scheme, bool chain_across_calls)
const std::vector< uint8_t > & iv() const
Definition: aes_cryptor.h:60