fix: remove extra block assumptions in mbedtls integration (#1323)

The current mbedtls integration was not working for some modes. See for
example #1316 and also lots of failing integration tests.

For example in pattern encryptor it works on one block at a time so it
cannot assume it's going to always get a buffer with a padding for an
extra block.

From what I can tell when the padding mode is correctly set to
`MBEDTLS_PADDING_NONE` there is no extra block being written to or
required.

This passes all crypto unit tests and integration tests.

Closes #1316
This commit is contained in:
Cosmin Stejerean 2024-02-08 19:16:52 +01:00 committed by GitHub
parent 9b9adf38ff
commit db59ad582a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 71 additions and 98 deletions

View File

@ -43,8 +43,7 @@ bool AesCryptor::Crypt(const std::vector<uint8_t>& text,
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
crypt_text->resize(text_size + NumPaddingBytes(text_size));
size_t crypt_text_size = crypt_text->size();
if (!Crypt(text.data(), text_size, crypt_text->data(), &crypt_text_size)) {
return false;
@ -58,8 +57,7 @@ bool AesCryptor::Crypt(const std::string& text, std::string* crypt_text) {
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
crypt_text->resize(text_size + NumPaddingBytes(text_size));
size_t crypt_text_size = crypt_text->size();
if (!Crypt(reinterpret_cast<const uint8_t*>(text.data()), text_size,
reinterpret_cast<uint8_t*>(&(*crypt_text)[0]), &crypt_text_size))

View File

@ -48,8 +48,7 @@ bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
}
size_t AesCbcDecryptor::RequiredOutputSize(size_t plaintext_size) {
// mbedtls requires a buffer large enough for one extra block.
return plaintext_size + AES_BLOCK_SIZE;
return plaintext_size;
}
bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
@ -60,14 +59,12 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
// Plaintext size is the same as ciphertext size except for pkcs5 padding.
// Will update later if using pkcs5 padding. For pkcs5 padding, we still
// need at least |ciphertext_size| bytes for intermediate operation.
// mbedtls requires a buffer large enough for one extra block.
const size_t required_plaintext_size = ciphertext_size + AES_BLOCK_SIZE;
if (*plaintext_size < required_plaintext_size) {
LOG(ERROR) << "Expecting output size of at least "
<< required_plaintext_size << " bytes.";
if (*plaintext_size < ciphertext_size) {
LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
<< " bytes.";
return false;
}
*plaintext_size = required_plaintext_size - AES_BLOCK_SIZE;
*plaintext_size = ciphertext_size;
// If the ciphertext size is 0, this can be a no-op decrypt, so long as the
// padding mode isn't PKCS5.
@ -83,15 +80,9 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
const size_t cbc_size = ciphertext_size - residual_block_size;
// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(ciphertext + cbc_size,
ciphertext + ciphertext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
if (residual_block_size == 0) {
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext);
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext,
internal_iv_.data());
if (padding_scheme_ != kPkcs5Padding)
return true;
@ -105,10 +96,11 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
*plaintext_size -= num_padding_bytes;
return true;
} else if (padding_scheme_ == kNoPadding) {
CbcDecryptBlocks(ciphertext, cbc_size, plaintext);
if (cbc_size > 0) {
CbcDecryptBlocks(ciphertext, cbc_size, plaintext, internal_iv_.data());
}
// The residual block is not encrypted.
memcpy(plaintext + cbc_size, residual_block.data(), residual_block_size);
memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
return true;
} else if (padding_scheme_ != kCtsPadding) {
LOG(ERROR) << "Expecting cipher text size to be multiple of "
@ -123,49 +115,44 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
return true;
}
// Copy the next-to-last block early, since mbedtls may overwrite one extra
// block of the output, and input and output may be the same buffer.
// NOTE: Before this point, there may not be such a block. Here, we know
// this is safe.
std::vector<uint8_t> next_to_last_block(
ciphertext + cbc_size - AES_BLOCK_SIZE, ciphertext + cbc_size);
// AES-CBC decrypt everything up to the next-to-last full block.
if (cbc_size > AES_BLOCK_SIZE) {
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext);
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext,
internal_iv_.data());
}
uint8_t* next_to_last_plaintext_block = plaintext + cbc_size - AES_BLOCK_SIZE;
const uint8_t* next_to_last_ciphertext_block =
ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
uint8_t* next_to_last_plaintext_block =
plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
// The next-to-last block should be decrypted first in ECB mode, which is
// effectively what you get with an IV of all zeroes.
std::vector<uint8_t> backup_iv(internal_iv_);
internal_iv_.assign(AES_BLOCK_SIZE, 0);
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> stolen_bits(AES_BLOCK_SIZE * 2);
CbcDecryptBlocks(next_to_last_block.data(), AES_BLOCK_SIZE,
stolen_bits.data());
// Determine what the last IV should be so that we can "skip ahead" in the
// CBC decryption.
std::vector<uint8_t> last_iv(
ciphertext + ciphertext_size - residual_block_size,
ciphertext + ciphertext_size);
last_iv.resize(AES_BLOCK_SIZE, 0);
// Reconstruct the final two blocks of ciphertext.
std::vector<uint8_t> reconstructed_blocks(AES_BLOCK_SIZE * 2);
memcpy(reconstructed_blocks.data(), residual_block.data(),
residual_block_size);
memcpy(reconstructed_blocks.data() + residual_block_size,
stolen_bits.data() + residual_block_size,
AES_BLOCK_SIZE - residual_block_size);
memcpy(reconstructed_blocks.data() + AES_BLOCK_SIZE,
next_to_last_block.data(), AES_BLOCK_SIZE);
// Decrypt the next-to-last block using the IV determined above. This decrypts
// the residual block bits.
CbcDecryptBlocks(next_to_last_ciphertext_block, AES_BLOCK_SIZE,
next_to_last_plaintext_block, last_iv.data());
// Decrypt the last two blocks.
internal_iv_ = backup_iv;
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> final_output_blocks(AES_BLOCK_SIZE * 3);
CbcDecryptBlocks(reconstructed_blocks.data(), AES_BLOCK_SIZE * 2,
final_output_blocks.data());
// Swap back the residual block bits and the next-to-last block.
if (plaintext == ciphertext) {
std::swap_ranges(next_to_last_plaintext_block,
next_to_last_plaintext_block + residual_block_size,
next_to_last_plaintext_block + AES_BLOCK_SIZE);
} else {
memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
next_to_last_plaintext_block, residual_block_size);
memcpy(next_to_last_plaintext_block,
next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
}
// Copy the final output.
memcpy(next_to_last_plaintext_block, final_output_blocks.data(),
AES_BLOCK_SIZE + residual_block_size);
// Decrypt the next-to-last full block.
CbcDecryptBlocks(next_to_last_plaintext_block, AES_BLOCK_SIZE,
next_to_last_plaintext_block, internal_iv_.data());
return true;
}
@ -176,7 +163,8 @@ void AesCbcDecryptor::SetIvInternal() {
void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
size_t ciphertext_size,
uint8_t* plaintext) {
uint8_t* plaintext,
uint8_t* iv) {
CHECK_EQ(ciphertext_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(ciphertext_size, 0u);
@ -186,14 +174,12 @@ void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
std::vector<uint8_t> next_iv(last_block, last_block + AES_BLOCK_SIZE);
size_t output_size = 0;
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(),
AES_BLOCK_SIZE, ciphertext, ciphertext_size,
plaintext, &output_size),
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, ciphertext,
ciphertext_size, plaintext, &output_size),
0);
DCHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
// Update the internal IV.
internal_iv_ = next_iv;
memcpy(iv, next_iv.data(), next_iv.size());
}
} // namespace media

View File

@ -58,7 +58,8 @@ class AesCbcDecryptor : public AesCryptor {
void SetIvInternal() override;
void CbcDecryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
uint8_t* ciphertext,
uint8_t* iv);
const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.

View File

@ -33,8 +33,7 @@ namespace media {
AesCtrEncryptor::AesCtrEncryptor()
: AesCryptor(kDontUseConstantIv),
block_offset_(0),
// mbedtls requires an extra output block.
encrypted_counter_(AES_BLOCK_SIZE * 2, 0) {}
encrypted_counter_(AES_BLOCK_SIZE, 0) {}
AesCtrEncryptor::~AesCtrEncryptor() {}
@ -129,8 +128,7 @@ bool AesCbcEncryptor::InitializeWithIv(const std::vector<uint8_t>& key,
}
size_t AesCbcEncryptor::RequiredOutputSize(size_t plaintext_size) {
// mbedtls requires a buffer large enough for one extra block.
return plaintext_size + NumPaddingBytes(plaintext_size) + AES_BLOCK_SIZE;
return plaintext_size + NumPaddingBytes(plaintext_size);
}
bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
@ -146,19 +144,12 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
<< required_ciphertext_size << " bytes.";
return false;
}
*ciphertext_size = required_ciphertext_size - AES_BLOCK_SIZE;
*ciphertext_size = required_ciphertext_size;
// Encrypt everything but the residual block using CBC.
const size_t cbc_size = plaintext_size - residual_block_size;
// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
if (cbc_size != 0) {
CbcEncryptBlocks(plaintext, cbc_size, ciphertext);
CbcEncryptBlocks(plaintext, cbc_size, ciphertext, internal_iv_.data());
} else if (padding_scheme_ == kCtsPadding) {
// Don't have a full block, leave unencrypted.
memcpy(ciphertext, plaintext, plaintext_size);
@ -175,27 +166,26 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
return true;
}
std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
uint8_t* residual_ciphertext_block = ciphertext + cbc_size;
if (padding_scheme_ == kPkcs5Padding) {
DCHECK_EQ(num_padding_bytes, AES_BLOCK_SIZE - residual_block_size);
// Pad residue block with PKCS5 padding.
residual_block.resize(AES_BLOCK_SIZE, static_cast<char>(num_padding_bytes));
CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
residual_ciphertext_block);
residual_ciphertext_block, internal_iv_.data());
} else {
DCHECK_EQ(num_padding_bytes, 0u);
DCHECK_EQ(padding_scheme_, kCtsPadding);
// Zero-pad the residual block and encrypt using CBC.
residual_block.resize(AES_BLOCK_SIZE, 0);
// mbedtls requires an extra block in the output buffer, and it cannot be
// the same as the input buffer.
std::vector<uint8_t> encrypted_residual_block(AES_BLOCK_SIZE * 2);
CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
encrypted_residual_block.data());
residual_block.data(), internal_iv_.data());
// Replace the last full block with the zero-padded, encrypted residual
// block, and replace the residual block with the equivalent portion of the
@ -206,8 +196,8 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
// https://en.wikipedia.org/wiki/Ciphertext_stealing#CS2
memcpy(residual_ciphertext_block,
residual_ciphertext_block - AES_BLOCK_SIZE, residual_block_size);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE,
encrypted_residual_block.data(), AES_BLOCK_SIZE);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE, residual_block.data(),
AES_BLOCK_SIZE);
}
return true;
}
@ -225,20 +215,20 @@ size_t AesCbcEncryptor::NumPaddingBytes(size_t size) const {
void AesCbcEncryptor::CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext) {
uint8_t* ciphertext,
uint8_t* iv) {
CHECK_EQ(plaintext_size % AES_BLOCK_SIZE, 0u);
size_t output_size = 0;
CHECK_EQ(
mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(), AES_BLOCK_SIZE,
plaintext, plaintext_size, ciphertext, &output_size),
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, iv, AES_BLOCK_SIZE, plaintext,
plaintext_size, ciphertext, &output_size),
0);
CHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(output_size, 0u);
uint8_t* last_block = ciphertext + output_size - AES_BLOCK_SIZE;
internal_iv_.assign(last_block, last_block + AES_BLOCK_SIZE);
memcpy(iv, last_block, AES_BLOCK_SIZE);
}
} // namespace media

View File

@ -96,7 +96,8 @@ class AesCbcEncryptor : public AesCryptor {
void CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
uint8_t* ciphertext,
uint8_t* iv);
const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.

View File

@ -75,8 +75,7 @@ void AesEcbEncrypt(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& plaintext,
std::vector<uint8_t>* ciphertext) {
CHECK_EQ(plaintext.size() % AES_BLOCK_SIZE, 0u);
// mbedtls requires an extra block worth of output buffer.
ciphertext->resize(plaintext.size() + AES_BLOCK_SIZE);
ciphertext->resize(plaintext.size());
mbedtls_cipher_context_t ctx;
mbedtls_cipher_init(&ctx);
@ -98,8 +97,6 @@ void AesEcbEncrypt(const std::vector<uint8_t>& key,
plaintext.data(), plaintext.size(),
ciphertext->data(), &output_size),
0);
// Truncate the output to the correct size.
ciphertext->resize(plaintext.size());
mbedtls_cipher_free(&ctx);
}