feat: Port AES and RSA crypto to mbedtls (#1119)

mbedtls works very differently from BoringSSL, and many changes had to
be made in the details of AES decryption to accomodate this.

Beyond the basic changes required for mbedtls, part of the CTS padding
implementation had to be rewritten. I believe this is because of an
assumption that held for BoringSSL, but not for mbedtls. I was unable to
determine what it was, so I rewrote the CTS decryption using reference
materials. After this, tests passed.

The deterministc PRNG I used with mbedtls in the RSA tests differs
somewhat from the old one, so the expected vectors had to be
regenerated. The old determinstic tests were also disabled, and are now
re-enabled.

Since cryptography is sensitive code, and because there were far more
changes needed here than just updating some headers and utility function
calls, this has been split into its own PR for separate review from the
rest of the media/base porting work.

Issue #1047 (CMake porting)
Issue #346 (absl porting)
This commit is contained in:
Joey Parrish 2022-11-02 08:34:06 -07:00 committed by GitHub
parent c3a4951597
commit 7b33f2065f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 633 additions and 439 deletions

View File

@ -32,4 +32,7 @@
/// You can use the insertion operator to add specific logs to this.
#define NOTIMPLEMENTED() LOG(ERROR) << "NOTIMPLEMENTED: "
/// AES block size in bytes, regardless of key size.
#define AES_BLOCK_SIZE 16
#endif // PACKAGER_MACROS_H_

View File

@ -5,4 +5,5 @@
# https://developers.google.com/open-source/licenses/bsd
# Subdirectories with their own CMakeLists.txt, all of whose targets are built.
add_subdirectory(base)
add_subdirectory(test)

View File

@ -0,0 +1,98 @@
# Copyright 2022 Google LLC. All rights reserved.
#
# Use of this source code is governed by a BSD-style
# license that can be found in the LICENSE file or at
# https://developers.google.com/open-source/licenses/bsd
# TODO: Add widevine_protos
add_library(media_base STATIC
# TODO: finish media_base
aes_cryptor.cc
aes_decryptor.cc
aes_encryptor.cc
aes_pattern_cryptor.cc
#audio_stream_info.cc
#audio_timestamp_helper.cc
#bit_reader.cc
#bit_writer.cc
#buffer_reader.cc
#buffer_writer.cc
#byte_queue.cc
#cc_stream_filter.cc
#closure_thread.cc
#common_pssh_generator.cc
#container_names.cc
#decrypt_config.cc
#decryptor_source.cc
#http_key_fetcher.cc
#id3_tag.cc
#key_fetcher.cc
#key_source.cc
#language_utils.cc
#media_handler.cc
#media_sample.cc
#muxer.cc
#muxer_options.cc
#muxer_util.cc
#network_util.cc
#offset_byte_queue.cc
#playready_key_source.cc
#playready_pssh_generator.cc
#protection_system_specific_info.cc
#proto_json_util.cc
#pssh_generator.cc
#pssh_generator_util.cc
#raw_key_source.cc
#request_signer.cc
rsa_key.cc
#stream_info.cc
#text_muxer.cc
#text_sample.cc
#text_stream_info.cc
#text_track_config.cc
#video_stream_info.cc
#video_util.cc
#widevine_key_source.cc
#widevine_pssh_generator.cc
)
target_link_libraries(media_base
absl::base
absl::strings
glog
mbedtls)
# TODO: lib media_handler_test_base
add_executable(media_base_unittest
# TODO: finish media_base_unittest
aes_cryptor_unittest.cc
aes_pattern_cryptor_unittest.cc
#audio_timestamp_helper_unittest.cc
#bit_reader_unittest.cc
#bit_writer_unittest.cc
#buffer_writer_unittest.cc
#closure_thread_unittest.cc
#container_names_unittest.cc
#decryptor_source_unittest.cc
#http_key_fetcher_unittest.cc
#id3_tag_unittest.cc
#muxer_util_unittest.cc
#offset_byte_queue_unittest.cc
#producer_consumer_queue_unittest.cc
#protection_system_specific_info_unittest.cc
#pssh_generator_unittest.cc
#raw_key_source_unittest.cc
rsa_key_unittest.cc
test/rsa_test_data.cc
#video_util_unittest.cc
#widevine_key_source_unittest.cc
)
target_link_libraries(media_base_unittest
media_base
gmock
gtest
gtest_main
test_data_util)
add_test(NAME media_base_unittest COMMAND media_base_unittest)

View File

@ -6,15 +6,11 @@
#include "packager/media/base/aes_cryptor.h"
#include <openssl/aes.h>
#include <openssl/crypto.h>
#include <openssl/err.h>
#include <openssl/rand.h>
#include <string>
#include <vector>
#include "packager/base/logging.h"
#include "glog/logging.h"
#include "mbedtls/entropy.h"
namespace {
@ -30,20 +26,21 @@ namespace shaka {
namespace media {
AesCryptor::AesCryptor(ConstantIvFlag constant_iv_flag)
: aes_key_(new AES_KEY),
constant_iv_flag_(constant_iv_flag),
num_crypt_bytes_(0) {
CRYPTO_library_init();
: constant_iv_flag_(constant_iv_flag), num_crypt_bytes_(0) {
mbedtls_cipher_init(&cipher_ctx_);
}
AesCryptor::~AesCryptor() {}
AesCryptor::~AesCryptor() {
mbedtls_cipher_free(&cipher_ctx_);
}
bool AesCryptor::Crypt(const std::vector<uint8_t>& text,
std::vector<uint8_t>* crypt_text) {
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
crypt_text->resize(text_size + NumPaddingBytes(text_size));
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
size_t crypt_text_size = crypt_text->size();
if (!Crypt(text.data(), text_size, crypt_text->data(), &crypt_text_size)) {
return false;
@ -57,7 +54,8 @@ bool AesCryptor::Crypt(const std::string& text, std::string* crypt_text) {
// Save text size to make it work for in-place conversion, since the
// next statement will update the text size.
const size_t text_size = text.size();
crypt_text->resize(text_size + NumPaddingBytes(text_size));
// mbedtls requires an extra block's worth of output buffer available.
crypt_text->resize(text_size + NumPaddingBytes(text_size) + AES_BLOCK_SIZE);
size_t crypt_text_size = crypt_text->size();
if (!Crypt(reinterpret_cast<const uint8_t*>(text.data()), text_size,
reinterpret_cast<uint8_t*>(&(*crypt_text)[0]), &crypt_text_size))
@ -98,7 +96,7 @@ void AesCryptor::UpdateIv() {
increment = (num_crypt_bytes_ + AES_BLOCK_SIZE - 1) / AES_BLOCK_SIZE;
}
for (int i = iv_.size() - 1; increment > 0 && i >= 0; --i) {
for (int64_t i = iv_.size() - 1; increment > 0 && i >= 0; --i) {
increment += iv_[i];
iv_[i] = increment & 0xFF;
increment >>= 8;
@ -118,9 +116,14 @@ bool AesCryptor::GenerateRandomIv(FourCC protection_scheme,
? 8
: 16;
iv->resize(iv_size);
if (RAND_bytes(iv->data(), iv_size) != 1) {
LOG(ERROR) << "RAND_bytes failed with error: "
<< ERR_error_string(ERR_get_error(), NULL);
mbedtls_entropy_context entropy_ctx;
mbedtls_entropy_init(&entropy_ctx);
int rv = mbedtls_entropy_func(&entropy_ctx, iv->data(), iv_size);
mbedtls_entropy_free(&entropy_ctx);
if (rv != 0) {
LOG(ERROR) << "mbedtls_entropy_func failed with: " << rv;
return false;
}
return true;
@ -128,8 +131,56 @@ bool AesCryptor::GenerateRandomIv(FourCC protection_scheme,
size_t AesCryptor::NumPaddingBytes(size_t size) const {
// No padding by default.
UNUSED(size);
return 0;
}
bool AesCryptor::SetupCipher(size_t key_size, CipherMode mode) {
mbedtls_cipher_type_t type;
// AES defines three key sizes: 128, 192 and 256 bits.
// NOTE: Because we use ECB mode in the CTR cryptors, this returns ECB
// instead of CTR. Counters and block offsets are managed internally.
switch (key_size) {
case 16:
type = mode == kCtrMode ? MBEDTLS_CIPHER_AES_128_ECB
: MBEDTLS_CIPHER_AES_128_CBC;
break;
case 24:
type = mode == kCtrMode ? MBEDTLS_CIPHER_AES_192_ECB
: MBEDTLS_CIPHER_AES_192_CBC;
break;
case 32:
type = mode == kCtrMode ? MBEDTLS_CIPHER_AES_256_ECB
: MBEDTLS_CIPHER_AES_256_CBC;
break;
default:
LOG(ERROR) << "Invalid AES key size: " << key_size;
return false;
}
const mbedtls_cipher_info_t* cipher_info =
mbedtls_cipher_info_from_type(type);
CHECK(cipher_info);
if (mbedtls_cipher_setup(&cipher_ctx_, cipher_info) != 0) {
LOG(ERROR) << "Cipher setup failed";
return false;
}
// Padding mode only applies to CBC.
if (mode == kCbcMode) {
// We handle padding ourselves. Don't let mbedtls mess with it.
mbedtls_cipher_padding_t cipher_padding = MBEDTLS_PADDING_NONE;
if (mbedtls_cipher_set_padding_mode(&cipher_ctx_, cipher_padding) != 0) {
LOG(ERROR) << "Failed to set CBC padding mode";
return false;
}
}
return true;
}
} // namespace media
} // namespace shaka

View File

@ -11,12 +11,10 @@
#include <string>
#include <vector>
#include "packager/base/macros.h"
#include "mbedtls/cipher.h"
#include "packager/macros.h"
#include "packager/media/base/fourccs.h"
struct aes_key_st;
typedef struct aes_key_st AES_KEY;
namespace shaka {
namespace media {
@ -92,8 +90,15 @@ class AesCryptor {
std::vector<uint8_t>* iv);
protected:
const AES_KEY* aes_key() const { return aes_key_.get(); }
AES_KEY* mutable_aes_key() { return aes_key_.get(); }
enum CipherMode {
kCtrMode,
kCbcMode,
};
// mbedTLS cipher context.
mbedtls_cipher_context_t cipher_ctx_;
bool SetupCipher(size_t key_size, CipherMode mode);
private:
// Internal implementation of crypt function.
@ -119,9 +124,6 @@ class AesCryptor {
// Note: No paddings should be needed except for pkcs5-cbc encryptor.
virtual size_t NumPaddingBytes(size_t size) const;
// Openssl AES_KEY.
std::unique_ptr<AES_KEY> aes_key_;
// Indicates whether a constant iv is used. Internal iv will be reset to
// |iv_| before calling Crypt if that is the case.
const ConstantIvFlag constant_iv_flag_;

View File

@ -8,8 +8,8 @@
#include <memory>
#include "packager/base/logging.h"
#include "packager/base/strings/string_number_conversions.h"
#include "absl/strings/escaping.h"
#include "glog/logging.h"
#include "packager/media/base/aes_decryptor.h"
#include "packager/media/base/aes_encryptor.h"
@ -144,12 +144,11 @@ TEST_F(AesCtrEncryptorTest, NistTestCaseInplaceEncryptionDecryption) {
TEST_F(AesCtrEncryptorTest, EncryptDecryptString) {
static const char kPlaintext[] = "normal plaintext of random length";
static const char kExpectedCiphertextInHex[] =
"82E3AD1EF90C5CC09EB37F1B9EFBD99016441A1C15123F0777CD57BB993E14DA02";
"82e3ad1ef90c5cc09eb37f1b9efbd99016441a1c15123f0777cd57bb993e14da02";
std::string ciphertext;
ASSERT_TRUE(encryptor_.Crypt(kPlaintext, &ciphertext));
EXPECT_EQ(kExpectedCiphertextInHex,
base::HexEncode(ciphertext.data(), ciphertext.size()));
EXPECT_EQ(kExpectedCiphertextInHex, absl::BytesToHexString(ciphertext));
std::string decrypted;
ASSERT_TRUE(decryptor_.SetIv(iv_));
@ -202,7 +201,8 @@ TEST_F(AesCtrEncryptorTest, GenerateRandomIv) {
std::vector<uint8_t> iv;
ASSERT_TRUE(AesCryptor::GenerateRandomIv(FOURCC_cenc, &iv));
ASSERT_EQ(kCencIvSize, iv.size());
LOG(INFO) << "Random IV: " << base::HexEncode(iv.data(), iv.size());
LOG(INFO) << "Random IV: "
<< absl::BytesToHexString(std::string(iv.begin(), iv.end()));
}
TEST_F(AesCtrEncryptorTest, UnsupportedKeySize) {
@ -427,16 +427,17 @@ TEST_F(AesCbcTest, Aes128CbcPkcs5) {
const std::string kPlaintext =
"Plain text with a g-clef U+1D11E \360\235\204\236";
const std::string kExpectedCiphertextHex =
"D4A67A0BA33C30F207344D81D1E944BBE65587C3D7D9939A"
"C070C62B9C15A3EA312EA4AD1BC7929F4D3C16B03AD5ADA8";
"d4a67a0ba33c30f207344d81d1e944bbe65587c3d7d9939a"
"c070c62b9c15a3ea312ea4ad1bc7929f4d3c16b03ad5ada8";
key_.assign(kKey.begin(), kKey.end());
iv_.assign(kIv.begin(), kIv.end());
const std::vector<uint8_t> plaintext(kPlaintext.begin(), kPlaintext.end());
std::vector<uint8_t> expected_ciphertext;
ASSERT_TRUE(
base::HexStringToBytes(kExpectedCiphertextHex, &expected_ciphertext));
std::string expected_ciphertext_string =
absl::HexStringToBytes(kExpectedCiphertextHex);
std::vector<uint8_t> expected_ciphertext(expected_ciphertext_string.begin(),
expected_ciphertext_string.end());
TestEncryptDecrypt(plaintext, expected_ciphertext);
}
@ -444,15 +445,16 @@ TEST_F(AesCbcTest, Aes192CbcPkcs5) {
const std::string kKey = "192bitsIsTwentyFourByte!";
const std::string kIv = "Sweet Sixteen IV";
const std::string kPlaintext = "Small text";
const std::string kExpectedCiphertextHex = "78DE5D7C2714FC5C61346C5416F6C89A";
const std::string kExpectedCiphertextHex = "78de5d7c2714fc5c61346c5416f6c89a";
key_.assign(kKey.begin(), kKey.end());
iv_.assign(kIv.begin(), kIv.end());
const std::vector<uint8_t> plaintext(kPlaintext.begin(), kPlaintext.end());
std::vector<uint8_t> expected_ciphertext;
ASSERT_TRUE(
base::HexStringToBytes(kExpectedCiphertextHex, &expected_ciphertext));
std::string expected_ciphertext_string =
absl::HexStringToBytes(kExpectedCiphertextHex);
std::vector<uint8_t> expected_ciphertext(expected_ciphertext_string.begin(),
expected_ciphertext_string.end());
TestEncryptDecrypt(plaintext, expected_ciphertext);
}
@ -557,10 +559,27 @@ TEST_F(AesCbcTest, Pkcs5CipherTextEmpty) {
EXPECT_FALSE(decryptor_->Crypt("", &plaintext));
}
std::ostream& operator<<(std::ostream& os, CbcPaddingScheme scheme) {
switch (scheme) {
case kNoPadding:
return os << "kNoPadding";
case kPkcs5Padding:
return os << "kPkcs5Padding";
case kCtsPadding:
return os << "kCtsPadding";
default:
return os << "Unrecognized scheme: " << scheme;
}
}
struct CbcTestCase {
CbcPaddingScheme padding_scheme;
const char* plaintext_hex;
const char* expected_ciphertext_hex;
friend std::ostream& operator<<(std::ostream& os, const CbcTestCase& param) {
return os << "padding_scheme = " << param.padding_scheme
<< ", plaintext = " << param.plaintext_hex;
}
};
const CbcTestCase kCbcTestCases[] = {
@ -608,14 +627,17 @@ TEST_P(AesCbcCryptorVerificationTest, EncryptDecryptTest) {
std::vector<uint8_t> plaintext;
std::string plaintext_hex(GetParam().plaintext_hex);
if (!plaintext_hex.empty()) {
ASSERT_TRUE(base::HexStringToBytes(plaintext_hex, &plaintext));
std::string plaintext_string = absl::HexStringToBytes(plaintext_hex);
plaintext.assign(plaintext_string.begin(), plaintext_string.end());
}
std::vector<uint8_t> expected_ciphertext;
std::string expected_ciphertext_hex(GetParam().expected_ciphertext_hex);
if (!expected_ciphertext_hex.empty()) {
ASSERT_TRUE(base::HexStringToBytes(GetParam().expected_ciphertext_hex,
&expected_ciphertext));
std::string expected_ciphertext_string =
absl::HexStringToBytes(expected_ciphertext_hex);
expected_ciphertext.assign(expected_ciphertext_string.begin(),
expected_ciphertext_string.end());
}
TestEncryptDecrypt(plaintext, expected_ciphertext);

View File

@ -6,18 +6,9 @@
#include "packager/media/base/aes_decryptor.h"
#include <openssl/aes.h>
#include <algorithm>
#include "packager/base/logging.h"
namespace {
// AES defines three key sizes: 128, 192 and 256 bits.
bool IsKeySizeValidForAes(size_t key_size) {
return key_size == 16 || key_size == 24 || key_size == 32;
}
} // namespace
#include "glog/logging.h"
namespace shaka {
namespace media {
@ -39,13 +30,17 @@ AesCbcDecryptor::~AesCbcDecryptor() {}
bool AesCbcDecryptor::InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) {
if (!IsKeySizeValidForAes(key.size())) {
LOG(ERROR) << "Invalid AES key size: " << key.size();
if (!SetupCipher(key.size(), kCbcMode)) {
return false;
}
if (mbedtls_cipher_setkey(&cipher_ctx_, key.data(),
static_cast<int>(8 * key.size()),
MBEDTLS_DECRYPT) != 0) {
LOG(ERROR) << "Failed to set CBC decryption key";
return false;
}
CHECK_EQ(AES_set_decrypt_key(key.data(), key.size() * 8, mutable_aes_key()),
0);
return SetIv(iv);
}
@ -54,17 +49,20 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
uint8_t* plaintext,
size_t* plaintext_size) {
DCHECK(plaintext_size);
DCHECK(aes_key());
// Plaintext size is the same as ciphertext size except for pkcs5 padding.
// Will update later if using pkcs5 padding. For pkcs5 padding, we still
// need at least |ciphertext_size| bytes for intermediate operation.
if (*plaintext_size < ciphertext_size) {
LOG(ERROR) << "Expecting output size of at least " << ciphertext_size
<< " bytes.";
// mbedtls requires a buffer large enough for one extra block.
const size_t required_plaintext_size = ciphertext_size + AES_BLOCK_SIZE;
if (*plaintext_size < required_plaintext_size) {
LOG(ERROR) << "Expecting output size of at least "
<< required_plaintext_size << " bytes.";
return false;
}
*plaintext_size = ciphertext_size;
*plaintext_size = required_plaintext_size - AES_BLOCK_SIZE;
// If the ciphertext size is 0, this can be a no-op decrypt, so long as the
// padding mode isn't PKCS5.
if (ciphertext_size == 0) {
if (padding_scheme_ == kPkcs5Padding) {
LOG(ERROR) << "Expected ciphertext to be at least " << AES_BLOCK_SIZE
@ -77,9 +75,15 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
const size_t residual_block_size = ciphertext_size % AES_BLOCK_SIZE;
const size_t cbc_size = ciphertext_size - residual_block_size;
// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(ciphertext + cbc_size,
ciphertext + ciphertext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
if (residual_block_size == 0) {
AES_cbc_encrypt(ciphertext, plaintext, ciphertext_size, aes_key(),
internal_iv_.data(), AES_DECRYPT);
CbcDecryptBlocks(ciphertext, ciphertext_size, plaintext);
if (padding_scheme_ != kPkcs5Padding)
return true;
@ -93,11 +97,10 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
*plaintext_size -= num_padding_bytes;
return true;
} else if (padding_scheme_ == kNoPadding) {
AES_cbc_encrypt(ciphertext, plaintext, cbc_size, aes_key(),
internal_iv_.data(), AES_DECRYPT);
CbcDecryptBlocks(ciphertext, cbc_size, plaintext);
// The residual block is not encrypted.
memcpy(plaintext + cbc_size, ciphertext + cbc_size, residual_block_size);
memcpy(plaintext + cbc_size, residual_block.data(), residual_block_size);
return true;
} else if (padding_scheme_ != kCtsPadding) {
LOG(ERROR) << "Expecting cipher text size to be multiple of "
@ -112,44 +115,49 @@ bool AesCbcDecryptor::CryptInternal(const uint8_t* ciphertext,
return true;
}
// Copy the next-to-last block early, since mbedtls may overwrite one extra
// block of the output, and input and output may be the same buffer.
// NOTE: Before this point, there may not be such a block. Here, we know
// this is safe.
std::vector<uint8_t> next_to_last_block(
ciphertext + cbc_size - AES_BLOCK_SIZE, ciphertext + cbc_size);
// AES-CBC decrypt everything up to the next-to-last full block.
if (cbc_size > AES_BLOCK_SIZE) {
AES_cbc_encrypt(ciphertext, plaintext, cbc_size - AES_BLOCK_SIZE, aes_key(),
internal_iv_.data(), AES_DECRYPT);
CbcDecryptBlocks(ciphertext, cbc_size - AES_BLOCK_SIZE, plaintext);
}
const uint8_t* next_to_last_ciphertext_block =
ciphertext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
uint8_t* next_to_last_plaintext_block =
plaintext + ciphertext_size - residual_block_size - AES_BLOCK_SIZE;
uint8_t* next_to_last_plaintext_block = plaintext + cbc_size - AES_BLOCK_SIZE;
// Determine what the last IV should be so that we can "skip ahead" in the
// CBC decryption.
std::vector<uint8_t> last_iv(
ciphertext + ciphertext_size - residual_block_size,
ciphertext + ciphertext_size);
last_iv.resize(AES_BLOCK_SIZE, 0);
// The next-to-last block should be decrypted first in ECB mode, which is
// effectively what you get with an IV of all zeroes.
std::vector<uint8_t> backup_iv(internal_iv_);
internal_iv_.assign(AES_BLOCK_SIZE, 0);
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> stolen_bits(AES_BLOCK_SIZE * 2);
CbcDecryptBlocks(next_to_last_block.data(), AES_BLOCK_SIZE,
stolen_bits.data());
// Decrypt the next-to-last block using the IV determined above. This decrypts
// the residual block bits.
AES_cbc_encrypt(next_to_last_ciphertext_block, next_to_last_plaintext_block,
AES_BLOCK_SIZE, aes_key(), last_iv.data(), AES_DECRYPT);
// Reconstruct the final two blocks of ciphertext.
std::vector<uint8_t> reconstructed_blocks(AES_BLOCK_SIZE * 2);
memcpy(reconstructed_blocks.data(), residual_block.data(),
residual_block_size);
memcpy(reconstructed_blocks.data() + residual_block_size,
stolen_bits.data() + residual_block_size,
AES_BLOCK_SIZE - residual_block_size);
memcpy(reconstructed_blocks.data() + AES_BLOCK_SIZE,
next_to_last_block.data(), AES_BLOCK_SIZE);
// Swap back the residual block bits and the next-to-last block.
if (plaintext == ciphertext) {
std::swap_ranges(next_to_last_plaintext_block,
next_to_last_plaintext_block + residual_block_size,
next_to_last_plaintext_block + AES_BLOCK_SIZE);
} else {
memcpy(next_to_last_plaintext_block + AES_BLOCK_SIZE,
next_to_last_plaintext_block, residual_block_size);
memcpy(next_to_last_plaintext_block,
next_to_last_ciphertext_block + AES_BLOCK_SIZE, residual_block_size);
}
// Decrypt the last two blocks.
internal_iv_ = backup_iv;
// mbedtls requires a buffer large enough for one extra block.
std::vector<uint8_t> final_output_blocks(AES_BLOCK_SIZE * 3);
CbcDecryptBlocks(reconstructed_blocks.data(), AES_BLOCK_SIZE * 2,
final_output_blocks.data());
// Decrypt the next-to-last full block.
AES_cbc_encrypt(next_to_last_plaintext_block, next_to_last_plaintext_block,
AES_BLOCK_SIZE, aes_key(), internal_iv_.data(), AES_DECRYPT);
// Copy the final output.
memcpy(next_to_last_plaintext_block, final_output_blocks.data(),
AES_BLOCK_SIZE + residual_block_size);
return true;
}
@ -158,5 +166,27 @@ void AesCbcDecryptor::SetIvInternal() {
internal_iv_.resize(AES_BLOCK_SIZE, 0);
}
void AesCbcDecryptor::CbcDecryptBlocks(const uint8_t* ciphertext,
size_t ciphertext_size,
uint8_t* plaintext) {
CHECK_EQ(ciphertext_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(ciphertext_size, 0u);
// Copy the final block of ciphertext before decryption, since we could be
// decrypting in-place.
const uint8_t* last_block = ciphertext + ciphertext_size - AES_BLOCK_SIZE;
std::vector<uint8_t> next_iv(last_block, last_block + AES_BLOCK_SIZE);
size_t output_size = 0;
CHECK_EQ(mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(),
AES_BLOCK_SIZE, ciphertext, ciphertext_size,
plaintext, &output_size),
0);
DCHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
// Update the internal IV.
internal_iv_ = next_iv;
}
} // namespace media
} // namespace shaka

View File

@ -4,14 +4,14 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
//
// AES Decryptor implementation using openssl.
// AES Decryptor implementation using mbedtls.
#ifndef PACKAGER_MEDIA_BASE_AES_DECRYPTOR_H_
#define PACKAGER_MEDIA_BASE_AES_DECRYPTOR_H_
#include <vector>
#include "packager/base/macros.h"
#include "packager/macros.h"
#include "packager/media/base/aes_cryptor.h"
#include "packager/media/base/aes_encryptor.h"
@ -54,6 +54,9 @@ class AesCbcDecryptor : public AesCryptor {
uint8_t* plaintext,
size_t* plaintext_size) override;
void SetIvInternal() override;
void CbcDecryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.

View File

@ -6,9 +6,7 @@
#include "packager/media/base/aes_encryptor.h"
#include <openssl/aes.h>
#include "packager/base/logging.h"
#include "glog/logging.h"
namespace {
@ -22,41 +20,36 @@ bool Increment64(uint8_t* counter) {
return true;
}
// AES defines three key sizes: 128, 192 and 256 bits.
bool IsKeySizeValidForAes(size_t key_size) {
return key_size == 16 || key_size == 24 || key_size == 32;
}
} // namespace
namespace shaka {
namespace media {
AesEncryptor::AesEncryptor(ConstantIvFlag constant_iv_flag)
: AesCryptor(constant_iv_flag) {}
AesEncryptor::~AesEncryptor() {}
bool AesEncryptor::InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) {
if (!IsKeySizeValidForAes(key.size())) {
LOG(ERROR) << "Invalid AES key size: " << key.size();
return false;
}
CHECK_EQ(AES_set_encrypt_key(key.data(), key.size() * 8, mutable_aes_key()),
0);
return SetIv(iv);
}
// We don't support constant iv for counter mode, as we don't have a use case
// for that.
AesCtrEncryptor::AesCtrEncryptor()
: AesEncryptor(kDontUseConstantIv),
: AesCryptor(kDontUseConstantIv),
block_offset_(0),
encrypted_counter_(AES_BLOCK_SIZE, 0) {}
// mbedtls requires an extra output block.
encrypted_counter_(AES_BLOCK_SIZE * 2, 0) {}
AesCtrEncryptor::~AesCtrEncryptor() {}
bool AesCtrEncryptor::InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) {
if (!SetupCipher(key.size(), kCtrMode)) {
return false;
}
if (mbedtls_cipher_setkey(&cipher_ctx_, key.data(),
static_cast<int>(8 * key.size()),
MBEDTLS_ENCRYPT) != 0) {
LOG(ERROR) << "Failed to set CTR encryption key";
return false;
}
return SetIv(iv);
}
bool AesCtrEncryptor::CryptInternal(const uint8_t* plaintext,
size_t plaintext_size,
@ -64,7 +57,6 @@ bool AesCtrEncryptor::CryptInternal(const uint8_t* plaintext,
size_t* ciphertext_size) {
DCHECK(plaintext);
DCHECK(ciphertext);
DCHECK(aes_key());
// |ciphertext_size| is always the same as |plaintext_size| for counter mode.
if (*ciphertext_size < plaintext_size) {
@ -76,7 +68,13 @@ bool AesCtrEncryptor::CryptInternal(const uint8_t* plaintext,
for (size_t i = 0; i < plaintext_size; ++i) {
if (block_offset_ == 0) {
AES_encrypt(&counter_[0], &encrypted_counter_[0], aes_key());
size_t ignored_output_size;
CHECK_EQ(
mbedtls_cipher_crypt(&cipher_ctx_, /* iv= */ NULL, /* iv_len= */ 0,
&counter_[0], AES_BLOCK_SIZE,
&encrypted_counter_[0], &ignored_output_size),
0);
// As mentioned in ISO/IEC 23001-7:2016 CENC spec, of the 16 byte counter
// block, bytes 8 to 15 (i.e. the least significant bytes) are used as a
// simple 64 bit unsigned integer that is incremented by one for each
@ -101,7 +99,7 @@ AesCbcEncryptor::AesCbcEncryptor(CbcPaddingScheme padding_scheme)
AesCbcEncryptor::AesCbcEncryptor(CbcPaddingScheme padding_scheme,
ConstantIvFlag constant_iv_flag)
: AesEncryptor(constant_iv_flag), padding_scheme_(padding_scheme) {
: AesCryptor(constant_iv_flag), padding_scheme_(padding_scheme) {
if (padding_scheme_ != kNoPadding) {
CHECK_EQ(constant_iv_flag, kUseConstantIv)
<< "non-constant iv (cipher block chain across calls) only makes sense "
@ -111,27 +109,49 @@ AesCbcEncryptor::AesCbcEncryptor(CbcPaddingScheme padding_scheme,
AesCbcEncryptor::~AesCbcEncryptor() {}
bool AesCbcEncryptor::InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) {
if (!SetupCipher(key.size(), kCbcMode)) {
return false;
}
if (mbedtls_cipher_setkey(&cipher_ctx_, key.data(),
static_cast<int>(8 * key.size()),
MBEDTLS_ENCRYPT) != 0) {
LOG(ERROR) << "Failed to set CBC encryption key";
return false;
}
return SetIv(iv);
}
bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext,
size_t* ciphertext_size) {
DCHECK(aes_key());
const size_t residual_block_size = plaintext_size % AES_BLOCK_SIZE;
const size_t num_padding_bytes = NumPaddingBytes(plaintext_size);
const size_t required_ciphertext_size = plaintext_size + num_padding_bytes;
// mbedtls requires a buffer large enough for one extra block.
const size_t required_ciphertext_size =
plaintext_size + num_padding_bytes + AES_BLOCK_SIZE;
if (*ciphertext_size < required_ciphertext_size) {
LOG(ERROR) << "Expecting output size of at least "
<< required_ciphertext_size << " bytes.";
return false;
}
*ciphertext_size = required_ciphertext_size;
*ciphertext_size = required_ciphertext_size - AES_BLOCK_SIZE;
// Encrypt everything but the residual block using CBC.
const size_t cbc_size = plaintext_size - residual_block_size;
// Copy the residual block early, since mbedtls may overwrite one extra block
// of the output, and input and output may be the same buffer.
std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
if (cbc_size != 0) {
AES_cbc_encrypt(plaintext, ciphertext, cbc_size, aes_key(),
internal_iv_.data(), AES_ENCRYPT);
CbcEncryptBlocks(plaintext, cbc_size, ciphertext);
} else if (padding_scheme_ == kCtsPadding) {
// Don't have a full block, leave unencrypted.
memcpy(ciphertext, plaintext, plaintext_size);
@ -148,38 +168,39 @@ bool AesCbcEncryptor::CryptInternal(const uint8_t* plaintext,
return true;
}
std::vector<uint8_t> residual_block(plaintext + cbc_size,
plaintext + plaintext_size);
DCHECK_EQ(residual_block.size(), residual_block_size);
uint8_t* residual_ciphertext_block = ciphertext + cbc_size;
if (padding_scheme_ == kPkcs5Padding) {
DCHECK_EQ(num_padding_bytes, AES_BLOCK_SIZE - residual_block_size);
// Pad residue block with PKCS5 padding.
residual_block.resize(AES_BLOCK_SIZE, static_cast<char>(num_padding_bytes));
AES_cbc_encrypt(residual_block.data(), residual_ciphertext_block,
AES_BLOCK_SIZE, aes_key(), internal_iv_.data(),
AES_ENCRYPT);
CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
residual_ciphertext_block);
} else {
DCHECK_EQ(num_padding_bytes, 0u);
DCHECK_EQ(padding_scheme_, kCtsPadding);
// Zero-pad the residual block and encrypt using CBC.
residual_block.resize(AES_BLOCK_SIZE, 0);
AES_cbc_encrypt(residual_block.data(), residual_block.data(),
AES_BLOCK_SIZE, aes_key(), internal_iv_.data(),
AES_ENCRYPT);
// mbedtls requires an extra block in the output buffer, and it cannot be
// the same as the input buffer.
std::vector<uint8_t> encrypted_residual_block(AES_BLOCK_SIZE * 2);
CbcEncryptBlocks(residual_block.data(), AES_BLOCK_SIZE,
encrypted_residual_block.data());
// Replace the last full block with the zero-padded, encrypted residual
// block, and replace the residual block with the equivalent portion of the
// last full encrypted block. It may appear that some encrypted bits of the
// last full block are lost, but they are not, as they were used as the IV
// when encrypting the zero-padded residual block.
// This ordering of the output is described as "CS2" in literature.
// https://en.wikipedia.org/wiki/Ciphertext_stealing#CS2
memcpy(residual_ciphertext_block,
residual_ciphertext_block - AES_BLOCK_SIZE, residual_block_size);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE, residual_block.data(),
AES_BLOCK_SIZE);
memcpy(residual_ciphertext_block - AES_BLOCK_SIZE,
encrypted_residual_block.data(), AES_BLOCK_SIZE);
}
return true;
}
@ -195,5 +216,23 @@ size_t AesCbcEncryptor::NumPaddingBytes(size_t size) const {
: 0;
}
void AesCbcEncryptor::CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext) {
CHECK_EQ(plaintext_size % AES_BLOCK_SIZE, 0u);
size_t output_size = 0;
CHECK_EQ(
mbedtls_cipher_crypt(&cipher_ctx_, internal_iv_.data(), AES_BLOCK_SIZE,
plaintext, plaintext_size, ciphertext, &output_size),
0);
CHECK_EQ(output_size % AES_BLOCK_SIZE, 0u);
CHECK_GT(output_size, 0u);
uint8_t* last_block = ciphertext + output_size - AES_BLOCK_SIZE;
internal_iv_.assign(last_block, last_block + AES_BLOCK_SIZE);
}
} // namespace media
} // namespace shaka

View File

@ -4,7 +4,7 @@
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
//
// AES Encryptor implementation using openssl.
// AES Encryptor implementation using mbedtls.
#ifndef PACKAGER_MEDIA_BASE_AES_ENCRYPTOR_H_
#define PACKAGER_MEDIA_BASE_AES_ENCRYPTOR_H_
@ -12,40 +12,25 @@
#include <string>
#include <vector>
#include "packager/base/macros.h"
#include "packager/macros.h"
#include "packager/media/base/aes_cryptor.h"
namespace shaka {
namespace media {
class AesEncryptor : public AesCryptor {
public:
/// @param constant_iv_flag indicates whether a constant iv is used,
/// kUseConstantIv means that the same iv is used for all Crypt calls
/// until iv is changed via SetIv; otherwise, iv can be incremented
/// (for counter mode) or chained (for cipher block chaining mode)
/// internally inside Crypt call, i.e. iv will be updated across Crypt
/// calls.
explicit AesEncryptor(ConstantIvFlag constant_iv_flag);
~AesEncryptor() override;
/// Initialize the encryptor with specified key and IV.
/// @return true on successful initialization, false otherwise.
bool InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) override;
private:
DISALLOW_COPY_AND_ASSIGN(AesEncryptor);
};
// Class which implements AES-CTR counter-mode encryption.
class AesCtrEncryptor : public AesEncryptor {
class AesCtrEncryptor : public AesCryptor {
public:
AesCtrEncryptor();
~AesCtrEncryptor() override;
uint32_t block_offset() const { return block_offset_; }
/// Initialize the encryptor with specified key and IV.
/// @return true on successful initialization, false otherwise.
bool InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) override;
private:
bool CryptInternal(const uint8_t* plaintext,
size_t plaintext_size,
@ -74,7 +59,7 @@ enum CbcPaddingScheme {
};
// Class which implements AES-CBC (Cipher block chaining) encryption.
class AesCbcEncryptor : public AesEncryptor {
class AesCbcEncryptor : public AesCryptor {
public:
/// Creates a AesCbcEncryptor with continous cipher block chain across Crypt
/// calls, i.e. AesCbcEncryptor(padding_scheme, kDontUseConstantIv).
@ -94,6 +79,11 @@ class AesCbcEncryptor : public AesEncryptor {
~AesCbcEncryptor() override;
/// Initialize the encryptor with specified key and IV.
/// @return true on successful initialization, false otherwise.
bool InitializeWithIv(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) override;
private:
bool CryptInternal(const uint8_t* plaintext,
size_t plaintext_size,
@ -102,6 +92,10 @@ class AesCbcEncryptor : public AesEncryptor {
void SetIvInternal() override;
size_t NumPaddingBytes(size_t size) const override;
void CbcEncryptBlocks(const uint8_t* plaintext,
size_t plaintext_size,
uint8_t* ciphertext);
const CbcPaddingScheme padding_scheme_;
// 16-byte internal iv for crypto operations.
std::vector<uint8_t> internal_iv_;

View File

@ -6,9 +6,9 @@
#include "packager/media/base/aes_pattern_cryptor.h"
#include <openssl/aes.h>
#include <algorithm>
#include "packager/base/logging.h"
#include "glog/logging.h"
namespace shaka {
namespace media {

View File

@ -8,7 +8,7 @@
#include <memory>
#include "packager/base/macros.h"
#include "packager/macros.h"
namespace shaka {
namespace media {

View File

@ -7,7 +7,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "packager/base/strings/string_number_conversions.h"
#include "absl/strings/escaping.h"
#include "packager/media/base/aes_pattern_cryptor.h"
#include "packager/media/base/mock_aes_cryptor.h"
@ -59,13 +59,16 @@ TEST_P(AesPatternCryptorVerificationTest, PatternTest) {
std::vector<uint8_t> text;
std::string text_hex(GetParam().text_hex);
if (!text_hex.empty()) {
ASSERT_TRUE(base::HexStringToBytes(text_hex, &text));
std::string text_string = absl::HexStringToBytes(text_hex);
text.assign(text_string.begin(), text_string.end());
}
std::vector<uint8_t> expected_crypt_text;
std::string expected_crypt_text_hex(GetParam().expected_crypt_text_hex);
if (!expected_crypt_text_hex.empty()) {
ASSERT_TRUE(
base::HexStringToBytes(expected_crypt_text_hex, &expected_crypt_text));
std::string expected_crypt_text_string =
absl::HexStringToBytes(expected_crypt_text_hex);
expected_crypt_text.assign(expected_crypt_text_string.begin(),
expected_crypt_text_string.end());
}
ON_CALL(*mock_cryptor_, CryptInternal(_, _, _, _))

View File

@ -18,66 +18,43 @@
#include "packager/media/base/rsa_key.h"
#include <openssl/err.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include <memory>
#include <vector>
#include "packager/base/logging.h"
#include "packager/base/sha1.h"
#include "glog/logging.h"
#include "mbedtls/error.h"
#include "mbedtls/md.h"
namespace {
const size_t kPssSaltLength = 20u;
// Serialize rsa key from DER encoded PKCS#1 RSAPrivateKey.
RSA* DeserializeRsaKey(const std::string& serialized_key,
bool deserialize_private_key) {
if (serialized_key.empty()) {
LOG(ERROR) << "Serialized RSA Key is empty.";
return NULL;
std::string mbedtls_strerr(int rv) {
// There is always a "high level" error.
std::string output(mbedtls_high_level_strerr(rv));
// Some errors have a "low level" error, which is like an inner error code
// with a deeper explanation. But on mac and Windows, ostream crashes if you
// give it NULL. So we combine them ourselves with a NULL check.
const char* low_level_error = mbedtls_low_level_strerr(rv);
if (low_level_error) {
output += ": ";
output += low_level_error;
}
BIO* bio = BIO_new_mem_buf(const_cast<char*>(serialized_key.data()),
serialized_key.size());
if (bio == NULL) {
LOG(ERROR) << "BIO_new_mem_buf returned NULL.";
return NULL;
}
RSA* rsa_key = deserialize_private_key ? d2i_RSAPrivateKey_bio(bio, NULL)
: d2i_RSAPublicKey_bio(bio, NULL);
BIO_free(bio);
return rsa_key;
return output;
}
RSA* DeserializeRsaPrivateKey(const std::string& serialized_key) {
RSA* rsa_key = DeserializeRsaKey(serialized_key, true);
if (!rsa_key) {
LOG(ERROR) << "Private RSA key deserialization failure.";
return NULL;
}
if (RSA_check_key(rsa_key) != 1) {
LOG(ERROR) << "Invalid RSA Private key: " << ERR_error_string(
ERR_get_error(), NULL);
RSA_free(rsa_key);
return NULL;
}
return rsa_key;
}
std::string sha1(const std::string& message) {
const mbedtls_md_info_t* md_info = mbedtls_md_info_from_type(MBEDTLS_MD_SHA1);
DCHECK(md_info);
RSA* DeserializeRsaPublicKey(const std::string& serialized_key) {
RSA* rsa_key = DeserializeRsaKey(serialized_key, false);
if (!rsa_key) {
LOG(ERROR) << "Private RSA key deserialization failure.";
return NULL;
}
if (RSA_size(rsa_key) <= 0) {
LOG(ERROR) << "Invalid RSA Public key: " << ERR_error_string(
ERR_get_error(), NULL);
RSA_free(rsa_key);
return NULL;
}
return rsa_key;
std::string hash(mbedtls_md_get_size(md_info), 0);
CHECK_EQ(0,
mbedtls_md(md_info, reinterpret_cast<const uint8_t*>(message.data()),
message.size(), reinterpret_cast<uint8_t*>(hash.data())));
return hash;
}
} // namespace
@ -85,39 +62,83 @@ RSA* DeserializeRsaPublicKey(const std::string& serialized_key) {
namespace shaka {
namespace media {
RsaPrivateKey::RsaPrivateKey(RSA* rsa_key) : rsa_key_(rsa_key) {
DCHECK(rsa_key);
RsaPrivateKey::RsaPrivateKey() {
mbedtls_pk_init(&pk_context_);
mbedtls_entropy_init(&entropy_context_);
mbedtls_ctr_drbg_init(&prng_context_);
}
RsaPrivateKey::~RsaPrivateKey() {
if (rsa_key_ != NULL)
RSA_free(rsa_key_);
mbedtls_pk_free(&pk_context_);
mbedtls_entropy_free(&entropy_context_);
mbedtls_ctr_drbg_free(&prng_context_);
}
RsaPrivateKey* RsaPrivateKey::Create(const std::string& serialized_key) {
RSA* rsa_key = DeserializeRsaPrivateKey(serialized_key);
return rsa_key == NULL ? NULL : new RsaPrivateKey(rsa_key);
std::unique_ptr<RsaPrivateKey> key(new RsaPrivateKey());
if (!key->Deserialize(serialized_key)) {
return NULL;
}
return key.release();
}
bool RsaPrivateKey::Deserialize(const std::string& serialized_key) {
const mbedtls_pk_info_t* pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
DCHECK(pk_info);
CHECK_EQ(mbedtls_ctr_drbg_seed(&prng_context_, mbedtls_entropy_func,
&entropy_context_, /* custom= */ NULL,
/* custom_len= */ 0),
0);
int rv = mbedtls_pk_parse_key(
&pk_context_, reinterpret_cast<const uint8_t*>(serialized_key.data()),
serialized_key.size(),
/* password= */ NULL,
/* password_len= */ 0, mbedtls_ctr_drbg_random, &prng_context_);
if (rv != 0) {
LOG(ERROR) << "RSA private key failed to load: " << mbedtls_strerr(rv);
return false;
}
// Set the padding mode and digest mode.
mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
rv = mbedtls_rsa_set_padding(rsa_context, MBEDTLS_RSA_PKCS_V21,
MBEDTLS_MD_SHA1);
if (rv != 0) {
LOG(ERROR) << "RSA private key failed to set padding: "
<< mbedtls_strerr(rv);
return false;
}
return true;
}
bool RsaPrivateKey::Decrypt(const std::string& encrypted_message,
std::string* decrypted_message) {
DCHECK(decrypted_message);
size_t rsa_size = RSA_size(rsa_key_);
mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
if (encrypted_message.size() != rsa_size) {
LOG(ERROR) << "Encrypted RSA message has the wrong size (expected "
<< rsa_size << ", actual " << encrypted_message.size() << ").";
return false;
}
decrypted_message->resize(encrypted_message.size());
decrypted_message->resize(rsa_size);
int decrypted_size = RSA_private_decrypt(
rsa_size, reinterpret_cast<const uint8_t*>(encrypted_message.data()),
reinterpret_cast<uint8_t*>(&(*decrypted_message)[0]), rsa_key_,
RSA_PKCS1_OAEP_PADDING);
size_t decrypted_size = 0;
int rv = mbedtls_rsa_rsaes_oaep_decrypt(
rsa_context, mbedtls_ctr_drbg_random, &prng_context_,
/* label= */ NULL,
/* label_len= */ 0, &decrypted_size,
reinterpret_cast<const uint8_t*>(encrypted_message.data()),
reinterpret_cast<uint8_t*>(decrypted_message->data()),
decrypted_message->size());
if (decrypted_size == -1) {
LOG(ERROR) << "RSA private decrypt failure: " << ERR_error_string(
ERR_get_error(), NULL);
if (rv != 0) {
LOG(ERROR) << "RSA private decrypt failure: " << mbedtls_strerr(rv);
return false;
}
decrypted_message->resize(decrypted_size);
@ -132,45 +153,73 @@ bool RsaPrivateKey::GenerateSignature(const std::string& message,
return false;
}
std::string message_digest = base::SHA1HashString(message);
mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
// Add PSS padding.
size_t rsa_size = RSA_size(rsa_key_);
std::vector<uint8_t> padded_digest(rsa_size);
if (!RSA_padding_add_PKCS1_PSS_mgf1(
rsa_key_, &padded_digest[0],
reinterpret_cast<uint8_t*>(&message_digest[0]), EVP_sha1(),
EVP_sha1(), kPssSaltLength)) {
LOG(ERROR) << "RSA padding failure: " << ERR_error_string(ERR_get_error(),
NULL);
return false;
}
// Encrypt PSS padded digest.
size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
signature->resize(rsa_size);
int signature_size = RSA_private_encrypt(
padded_digest.size(), &padded_digest[0],
reinterpret_cast<uint8_t*>(&(*signature)[0]), rsa_key_, RSA_NO_PADDING);
if (signature_size != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA private encrypt failure: " << ERR_error_string(
ERR_get_error(), NULL);
std::string hash = sha1(message);
int rv = mbedtls_rsa_rsassa_pss_sign_ext(
rsa_context, mbedtls_ctr_drbg_random, &prng_context_, MBEDTLS_MD_SHA1,
static_cast<unsigned int>(hash.size()),
reinterpret_cast<const uint8_t*>(hash.data()), kPssSaltLength,
reinterpret_cast<uint8_t*>(signature->data()));
if (rv != 0) {
LOG(ERROR) << "RSA sign failure: " << mbedtls_strerr(rv);
return false;
}
return true;
}
RsaPublicKey::RsaPublicKey(RSA* rsa_key) : rsa_key_(rsa_key) {
DCHECK(rsa_key);
RsaPublicKey::RsaPublicKey() {
mbedtls_pk_init(&pk_context_);
mbedtls_entropy_init(&entropy_context_);
mbedtls_ctr_drbg_init(&prng_context_);
}
RsaPublicKey::~RsaPublicKey() {
if (rsa_key_ != NULL)
RSA_free(rsa_key_);
mbedtls_pk_free(&pk_context_);
mbedtls_entropy_free(&entropy_context_);
mbedtls_ctr_drbg_free(&prng_context_);
}
RsaPublicKey* RsaPublicKey::Create(const std::string& serialized_key) {
RSA* rsa_key = DeserializeRsaPublicKey(serialized_key);
return rsa_key == NULL ? NULL : new RsaPublicKey(rsa_key);
std::unique_ptr<RsaPublicKey> key(new RsaPublicKey());
if (!key->Deserialize(serialized_key)) {
return NULL;
}
return key.release();
}
bool RsaPublicKey::Deserialize(const std::string& serialized_key) {
const mbedtls_pk_info_t* pk_info = mbedtls_pk_info_from_type(MBEDTLS_PK_RSA);
DCHECK(pk_info);
CHECK_EQ(mbedtls_ctr_drbg_seed(&prng_context_, mbedtls_entropy_func,
&entropy_context_, /* custom= */ NULL,
/* custom_len= */ 0),
0);
int rv = mbedtls_pk_parse_public_key(
&pk_context_, reinterpret_cast<const uint8_t*>(serialized_key.data()),
serialized_key.size());
if (rv != 0) {
LOG(ERROR) << "RSA public key failed to load: " << mbedtls_strerr(rv);
return false;
}
// Set the padding mode and digest mode.
mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
rv = mbedtls_rsa_set_padding(rsa_context, MBEDTLS_RSA_PKCS_V21,
MBEDTLS_MD_SHA1);
if (rv != 0) {
LOG(ERROR) << "RSA public key failed to set padding: "
<< mbedtls_strerr(rv);
return false;
}
return true;
}
bool RsaPublicKey::Encrypt(const std::string& clear_message,
@ -181,17 +230,20 @@ bool RsaPublicKey::Encrypt(const std::string& clear_message,
return false;
}
size_t rsa_size = RSA_size(rsa_key_);
encrypted_message->resize(rsa_size);
int encrypted_size =
RSA_public_encrypt(clear_message.size(),
reinterpret_cast<const uint8_t*>(clear_message.data()),
reinterpret_cast<uint8_t*>(&(*encrypted_message)[0]),
rsa_key_, RSA_PKCS1_OAEP_PADDING);
mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
if (encrypted_size != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA public encrypt failure: " << ERR_error_string(
ERR_get_error(), NULL);
size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
encrypted_message->resize(rsa_size);
int rv = mbedtls_rsa_rsaes_oaep_encrypt(
rsa_context, mbedtls_ctr_drbg_random, &prng_context_,
/* label= */ NULL,
/* label_len= */ 0, clear_message.size(),
reinterpret_cast<const uint8_t*>(clear_message.data()),
reinterpret_cast<uint8_t*>(encrypted_message->data()));
if (rv != 0) {
LOG(ERROR) << "RSA public encrypt failure: " << mbedtls_strerr(rv);
return false;
}
return true;
@ -204,38 +256,27 @@ bool RsaPublicKey::VerifySignature(const std::string& message,
return false;
}
size_t rsa_size = RSA_size(rsa_key_);
mbedtls_rsa_context* rsa_context = mbedtls_pk_rsa(pk_context_);
size_t rsa_size = mbedtls_rsa_get_len(rsa_context);
if (signature.size() != rsa_size) {
LOG(ERROR) << "Message signature is of the wrong size (expected "
<< rsa_size << ", actual " << signature.size() << ").";
return false;
}
// Decrypt the signature.
std::vector<uint8_t> padded_digest(signature.size());
int decrypted_size =
RSA_public_decrypt(signature.size(),
reinterpret_cast<const uint8_t*>(signature.data()),
&padded_digest[0],
rsa_key_,
RSA_NO_PADDING);
// Verify the signature.
std::string hash = sha1(message);
int rv = mbedtls_rsa_rsassa_pss_verify_ext(
rsa_context, MBEDTLS_MD_SHA1, static_cast<unsigned int>(hash.size()),
reinterpret_cast<const uint8_t*>(hash.data()), MBEDTLS_MD_SHA1,
kPssSaltLength, reinterpret_cast<const uint8_t*>(signature.data()));
if (decrypted_size != static_cast<int>(rsa_size)) {
LOG(ERROR) << "RSA public decrypt failure: " << ERR_error_string(
ERR_get_error(), NULL);
if (rv != 0) {
LOG(ERROR) << "RSA signature verification failed: " << mbedtls_strerr(rv);
return false;
}
std::string message_digest = base::SHA1HashString(message);
// Verify PSS padding.
return RSA_verify_PKCS1_PSS_mgf1(
rsa_key_,
reinterpret_cast<const uint8_t*>(message_digest.data()),
EVP_sha1(),
EVP_sha1(),
&padded_digest[0],
kPssSaltLength) != 0;
return true;
}
} // namespace media

View File

@ -12,10 +12,10 @@
#include <string>
#include "packager/base/macros.h"
struct rsa_st;
typedef struct rsa_st RSA;
#include "mbedtls/ctr_drbg.h"
#include "mbedtls/entropy.h"
#include "mbedtls/pk.h"
#include "packager/macros.h"
namespace shaka {
namespace media {
@ -41,10 +41,13 @@ class RsaPrivateKey {
bool GenerateSignature(const std::string& message, std::string* signature);
private:
// RsaPrivateKey takes owership of |rsa_key|.
explicit RsaPrivateKey(RSA* rsa_key);
RsaPrivateKey();
RSA* rsa_key_; // owned
bool Deserialize(const std::string& serialized_key);
mbedtls_pk_context pk_context_;
mbedtls_entropy_context entropy_context_;
mbedtls_ctr_drbg_context prng_context_;
DISALLOW_COPY_AND_ASSIGN(RsaPrivateKey);
};
@ -70,10 +73,13 @@ class RsaPublicKey {
const std::string& signature);
private:
// RsaPublicKey takes owership of |rsa_key|.
explicit RsaPublicKey(RSA* rsa_key);
RsaPublicKey();
RSA* rsa_key_; // owned
bool Deserialize(const std::string& serialized_key);
mbedtls_pk_context pk_context_;
mbedtls_entropy_context entropy_context_;
mbedtls_ctr_drbg_context prng_context_;
DISALLOW_COPY_AND_ASSIGN(RsaPublicKey);
};

View File

@ -7,39 +7,30 @@
// Unit test for rsa_key RSA encryption and signing.
#include <gtest/gtest.h>
#include <filesystem>
#include <memory>
#include "packager/media/base/rsa_key.h"
#include "packager/media/base/test/fake_prng.h"
#include "packager/media/base/test/rsa_test_data.h"
namespace {
// BoringSSL does not support RAND_set_rand_method yet, so we cannot use fake
// prng with boringssl.
const bool kIsFakePrngSupported = false;
} // namespace
#include "glog/logging.h"
#include "packager/media/base/rsa_key.h"
#include "packager/media/base/test/rsa_test_data.h"
#include "packager/media/test/test_data_util.h"
namespace shaka {
namespace media {
namespace {
class RsaKeyTest : public ::testing::TestWithParam<RsaTestSet> {
public:
RsaKeyTest() : test_set_(GetParam()) {}
void SetUp() override {
if (kIsFakePrngSupported) {
// Make OpenSSL RSA deterministic.
ASSERT_TRUE(fake_prng::StartFakePrng());
}
private_key_.reset(RsaPrivateKey::Create(test_set_.private_key));
ASSERT_TRUE(private_key_ != NULL);
public_key_.reset(RsaPublicKey::Create(test_set_.public_key));
ASSERT_TRUE(public_key_ != NULL);
}
void TearDown() override {
if (kIsFakePrngSupported)
fake_prng::StopFakePrng();
}
protected:
const RsaTestSet& test_set_;
@ -85,16 +76,23 @@ TEST_P(RsaKeyTest, LoadPrivateKeyInPublicKey) {
TEST_P(RsaKeyTest, EncryptAndDecrypt) {
std::string encrypted_message;
EXPECT_TRUE(public_key_->Encrypt(test_set_.test_message, &encrypted_message));
if (kIsFakePrngSupported) {
EXPECT_EQ(test_set_.encrypted_message, encrypted_message);
}
ASSERT_TRUE(public_key_->Encrypt(test_set_.test_message, &encrypted_message));
std::string decrypted_message;
EXPECT_TRUE(private_key_->Decrypt(encrypted_message, &decrypted_message));
EXPECT_EQ(test_set_.test_message, decrypted_message);
}
TEST_P(RsaKeyTest, DecryptGoldenMessage) {
// This message is from an older version that predates our use of mbedtls,
// but proves that the new system is compatible with the messages produced by
// the old one.
std::string decrypted_message;
EXPECT_TRUE(
private_key_->Decrypt(test_set_.encrypted_message, &decrypted_message));
EXPECT_EQ(test_set_.test_message, decrypted_message);
}
TEST_P(RsaKeyTest, BadEncMessage1) {
// Add a byte to the encrypted message.
std::string bad_enc_message = test_set_.encrypted_message + '\0';
@ -123,14 +121,19 @@ TEST_P(RsaKeyTest, BadEncMessage3) {
TEST_P(RsaKeyTest, SignAndVerify) {
std::string signature;
EXPECT_TRUE(
ASSERT_TRUE(
private_key_->GenerateSignature(test_set_.test_message, &signature));
if (kIsFakePrngSupported) {
EXPECT_EQ(test_set_.signature, signature);
}
EXPECT_TRUE(public_key_->VerifySignature(test_set_.test_message, signature));
}
TEST_P(RsaKeyTest, VerifyGoldenSignature) {
// This signature is from an older version that predates our use of mbedtls,
// but proves that the new system is compatible with the signatures produced
// by the old one.
EXPECT_TRUE(public_key_->VerifySignature(test_set_.test_message,
test_set_.signature));
}
TEST_P(RsaKeyTest, BadSignature1) {
// Add a byte to the signature.
std::string bad_signature = test_set_.signature + '\0';
@ -161,5 +164,6 @@ INSTANTIATE_TEST_CASE_P(RsaTestKeys,
::testing::Values(RsaTestData().test_set_3072_bits(),
RsaTestData().test_set_2048_bits()));
} // namespace
} // namespace media
} // namespace shaka

View File

@ -1,75 +0,0 @@
// Copyright 2014 Google LLC. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#include "packager/media/base/test/fake_prng.h"
#include <openssl/rand.h>
#include "packager/base/files/file_util.h"
#include "packager/base/logging.h"
#include "packager/media/test/test_data_util.h"
namespace {
FILE* g_rand_source_fp = NULL;
const char kFakePrngDataFile[] = "fake_prng_data.bin";
// RAND_bytes and RAND_pseudorand implementation.
int FakeBytes(uint8_t* buf, size_t num) {
DCHECK(buf);
DCHECK(g_rand_source_fp);
if (fread(buf, 1, num, g_rand_source_fp) < static_cast<size_t>(num)) {
LOG(ERROR) << "Ran out of fake PRNG data";
return 0;
}
return 1;
}
const RAND_METHOD kFakeRandMethod = {NULL, // RAND_seed function.
FakeBytes, // RAND_bytes function.
NULL, // RAND_cleanup function.
NULL, // RAND_add function.
FakeBytes, // RAND_pseudorand function.
NULL}; // RAND_status function.
} // namespace
namespace shaka {
namespace media {
namespace fake_prng {
bool StartFakePrng() {
if (g_rand_source_fp) {
LOG(ERROR) << "Fake PRNG already started.";
return false;
}
// Open deterministic random data file and set the OpenSSL PRNG.
g_rand_source_fp =
base::OpenFile(GetTestDataFilePath(kFakePrngDataFile), "rb");
if (!g_rand_source_fp) {
LOG(ERROR) << "Cannot open " << kFakePrngDataFile;
return false;
}
RAND_set_rand_method(&kFakeRandMethod);
return true;
}
void StopFakePrng() {
if (g_rand_source_fp) {
base::CloseFile(g_rand_source_fp);
g_rand_source_fp = NULL;
} else {
LOG(WARNING) << "Fake PRNG not started.";
}
RAND_set_rand_method(RAND_SSLeay());
}
} // namespace fake_prng
} // namespace media
} // namespace shaka

View File

@ -1,27 +0,0 @@
// Copyright 2014 Google LLC. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
//
// Fake, deterministic PRNG for OpenSSL to be used for unit testing.
#ifndef PACKAGER_MEDIA_BASE_FAKE_PRNG_H
#define PACKAGER_MEDIA_BASE_FAKE_PRNG_H
namespace shaka {
namespace media {
namespace fake_prng {
/// Start using fake, deterministic PRNG for OpenSSL.
/// @return true if successful, false otherwise.
bool StartFakePrng();
/// Stop using fake, deterministic PRNG for OpenSSL.
void StopFakePrng();
} // namespace fake_prng
} // namespace media
} // namespace shaka
#endif // PACKAGER_MEDIA_BASE_FAKE_PRNG_H

View File

@ -360,7 +360,6 @@ const uint8_t kTestEncryptedMessage_3072[] = {
0x0a, 0xed, 0x2a, 0xa3, 0xec, 0x97, 0x01, 0xfb, 0xee, 0x28, 0xd7, 0xfc,
0x34, 0xd5, 0x1a, 0x62, 0x9c, 0xb2, 0x9d, 0x8b, 0xe9, 0x49, 0x48, 0x1d};
// Self-generated test vector. Used to verify algorithm stability.
const uint8_t kTestEncryptedMessage_2048[] = {
0x73, 0x37, 0xa5, 0xe3, 0x73, 0xbb, 0xa7, 0xbf, 0xb1, 0xfc, 0x98, 0x6c,
0xd2, 0x20, 0xe2, 0x79, 0xea, 0x90, 0x41, 0xcf, 0x2b, 0xe0, 0x22, 0x0f,

View File

@ -11,7 +11,7 @@
#include <string>
#include "packager/base/macros.h"
#include "packager/macros.h"
namespace shaka {
namespace media {