Move encryptor setup out of EncryptionHandler

Created EncryptorFactory to set up the encryptors.

This is part of the EncryptionHandler clean up to make it more modular
and testable.

Change-Id: I839bcd8a84fa873396360d67afb540fef1345673
This commit is contained in:
KongQun Yang 2018-10-04 13:24:21 -07:00
parent 8d11e5ea64
commit acaa6b9b3b
9 changed files with 322 additions and 346 deletions

View File

@ -9,6 +9,7 @@
#include "packager/base/strings/string_number_conversions.h" #include "packager/base/strings/string_number_conversions.h"
#include "packager/media/base/aes_pattern_cryptor.h" #include "packager/media/base/aes_pattern_cryptor.h"
#include "packager/media/base/mock_aes_cryptor.h"
using ::testing::_; using ::testing::_;
using ::testing::Invoke; using ::testing::Invoke;
@ -22,21 +23,6 @@ const uint8_t kSkipByteBlock = 1u;
namespace shaka { namespace shaka {
namespace media { namespace media {
class MockAesCryptor : public AesCryptor {
public:
MockAesCryptor() : AesCryptor(kDontUseConstantIv) {}
MOCK_METHOD2(InitializeWithIv,
bool(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv));
MOCK_METHOD4(CryptInternal,
bool(const uint8_t* text,
size_t text_size,
uint8_t* crypt_text,
size_t* crypt_text_size));
MOCK_METHOD0(SetIvInternal, void());
};
class AesPatternCryptorTest : public ::testing::Test { class AesPatternCryptorTest : public ::testing::Test {
public: public:
AesPatternCryptorTest() AesPatternCryptorTest()

View File

@ -0,0 +1,33 @@
// Copyright 2018 Google Inc. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#ifndef PACKAGER_MEDIA_BASE_MOCK_AES_CRYPTOR_H_
#define PACKAGER_MEDIA_BASE_MOCK_AES_CRYPTOR_H_
#include "packager/media/base/aes_cryptor.h"
namespace shaka {
namespace media {
class MockAesCryptor : public AesCryptor {
public:
MockAesCryptor() : AesCryptor(kDontUseConstantIv) {}
MOCK_METHOD2(InitializeWithIv,
bool(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv));
MOCK_METHOD4(CryptInternal,
bool(const uint8_t* text,
size_t text_size,
uint8_t* crypt_text,
size_t* crypt_text_size));
MOCK_METHOD0(SetIvInternal, void());
};
} // namespace media
} // namespace shaka
#endif // PACKAGER_MEDIA_BASE_MOCK_AES_CRYPTOR_H_

View File

@ -0,0 +1,84 @@
// Copyright 2018 Google LLC. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#include "packager/media/crypto/aes_encryptor_factory.h"
#include "packager/media/base/aes_encryptor.h"
#include "packager/media/base/aes_pattern_cryptor.h"
#include "packager/media/crypto/sample_aes_ec3_cryptor.h"
namespace shaka {
namespace media {
std::unique_ptr<AesCryptor> AesEncryptorFactory::CreateEncryptor(
FourCC protection_scheme,
uint8_t crypt_byte_block,
uint8_t skip_byte_block,
Codec codec,
const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv) {
std::unique_ptr<AesCryptor> encryptor;
switch (protection_scheme) {
case FOURCC_cenc:
encryptor.reset(new AesCtrEncryptor);
break;
case FOURCC_cbc1:
encryptor.reset(new AesCbcEncryptor(kNoPadding));
break;
case FOURCC_cens:
encryptor.reset(new AesPatternCryptor(
crypt_byte_block, skip_byte_block,
AesPatternCryptor::kEncryptIfCryptByteBlockRemaining,
AesCryptor::kDontUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCtrEncryptor)));
break;
case FOURCC_cbcs:
encryptor.reset(new AesPatternCryptor(
crypt_byte_block, skip_byte_block,
AesPatternCryptor::kEncryptIfCryptByteBlockRemaining,
AesCryptor::kUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCbcEncryptor(kNoPadding))));
break;
case kAppleSampleAesProtectionScheme:
if (crypt_byte_block == 0 && skip_byte_block == 0) {
if (codec == kCodecEAC3) {
encryptor.reset(new SampleAesEc3Cryptor(
std::unique_ptr<AesCryptor>(new AesCbcEncryptor(kNoPadding))));
} else {
encryptor.reset(
new AesCbcEncryptor(kNoPadding, AesCryptor::kUseConstantIv));
}
} else {
encryptor.reset(new AesPatternCryptor(
crypt_byte_block, skip_byte_block,
AesPatternCryptor::kSkipIfCryptByteBlockRemaining,
AesCryptor::kUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCbcEncryptor(kNoPadding))));
}
break;
default:
LOG(ERROR) << "Unsupported protection scheme.";
return nullptr;
}
if (iv.empty()) {
std::vector<uint8_t> random_iv;
if (!AesCryptor::GenerateRandomIv(protection_scheme, &random_iv)) {
LOG(ERROR) << "Failed to generate random iv.";
return nullptr;
}
if (!encryptor->InitializeWithIv(key, random_iv))
return nullptr;
} else {
if (!encryptor->InitializeWithIv(key, iv))
return nullptr;
}
return encryptor;
}
} // namespace media
} // namespace shaka

View File

@ -0,0 +1,41 @@
// Copyright 2018 Google LLC. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file or at
// https://developers.google.com/open-source/licenses/bsd
#ifndef PACKAGER_MEDIA_CRYPTO_AES_ENCRYPTOR_FACTORY_H_
#define PACKAGER_MEDIA_CRYPTO_AES_ENCRYPTOR_FACTORY_H_
#include "packager/media/base/fourccs.h"
#include "packager/media/base/stream_info.h"
namespace shaka {
namespace media {
class AesCryptor;
/// A factory class to create encryptors.
class AesEncryptorFactory {
public:
AesEncryptorFactory() = default;
virtual ~AesEncryptorFactory() = default;
// Virtual for mocking.
virtual std::unique_ptr<AesCryptor> CreateEncryptor(
FourCC protection_scheme,
uint8_t crypt_byte_block,
uint8_t skip_byte_block,
Codec codec,
const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv);
private:
AesEncryptorFactory(const AesEncryptorFactory&) = delete;
AesEncryptorFactory& operator=(const AesEncryptorFactory&) = delete;
};
} // namespace media
} // namespace shaka
#endif // PACKAGER_MEDIA_CRYPTO_AES_ENCRYPTOR_FACTORY_H_

View File

@ -13,6 +13,8 @@
'target_name': 'crypto', 'target_name': 'crypto',
'type': '<(component)', 'type': '<(component)',
'sources': [ 'sources': [
'aes_encryptor_factory.cc',
'aes_encryptor_factory.h',
'encryption_handler.cc', 'encryption_handler.cc',
'encryption_handler.h', 'encryption_handler.h',
'sample_aes_ec3_cryptor.cc', 'sample_aes_ec3_cryptor.cc',

View File

@ -13,7 +13,6 @@
#include <limits> #include <limits>
#include "packager/media/base/aes_encryptor.h" #include "packager/media/base/aes_encryptor.h"
#include "packager/media/base/aes_pattern_cryptor.h"
#include "packager/media/base/audio_stream_info.h" #include "packager/media/base/audio_stream_info.h"
#include "packager/media/base/key_source.h" #include "packager/media/base/key_source.h"
#include "packager/media/base/media_sample.h" #include "packager/media/base/media_sample.h"
@ -21,7 +20,7 @@
#include "packager/media/codecs/video_slice_header_parser.h" #include "packager/media/codecs/video_slice_header_parser.h"
#include "packager/media/codecs/vp8_parser.h" #include "packager/media/codecs/vp8_parser.h"
#include "packager/media/codecs/vp9_parser.h" #include "packager/media/codecs/vp9_parser.h"
#include "packager/media/crypto/sample_aes_ec3_cryptor.h" #include "packager/media/crypto/aes_encryptor_factory.h"
#include "packager/status_macros.h" #include "packager/status_macros.h"
namespace shaka { namespace shaka {
@ -89,6 +88,12 @@ std::string GetStreamLabelForEncryption(
} }
return stream_label_func(stream_attributes); return stream_label_func(stream_attributes);
} }
bool IsPatternEncryptionScheme(FourCC protection_scheme) {
return protection_scheme == kAppleSampleAesProtectionScheme ||
protection_scheme == FOURCC_cbcs || protection_scheme == FOURCC_cens;
}
} // namespace } // namespace
EncryptionHandler::EncryptionHandler(const EncryptionParams& encryption_params, EncryptionHandler::EncryptionHandler(const EncryptionParams& encryption_params,
@ -96,9 +101,10 @@ EncryptionHandler::EncryptionHandler(const EncryptionParams& encryption_params,
: encryption_params_(encryption_params), : encryption_params_(encryption_params),
protection_scheme_( protection_scheme_(
static_cast<FourCC>(encryption_params.protection_scheme)), static_cast<FourCC>(encryption_params.protection_scheme)),
key_source_(key_source) {} key_source_(key_source),
encryptor_factory_(new AesEncryptorFactory) {}
EncryptionHandler::~EncryptionHandler() {} EncryptionHandler::~EncryptionHandler() = default;
Status EncryptionHandler::InitializeInternal() { Status EncryptionHandler::InitializeInternal() {
if (!encryption_params_.stream_label_func) { if (!encryption_params_.stream_label_func) {
@ -311,140 +317,71 @@ Status EncryptionHandler::ProcessMediaSample(
} }
Status EncryptionHandler::SetupProtectionPattern(StreamType stream_type) { Status EncryptionHandler::SetupProtectionPattern(StreamType stream_type) {
switch (protection_scheme_) { if (protection_scheme_ == kAppleSampleAesProtectionScheme) {
case kAppleSampleAesProtectionScheme: { const size_t kH264LeadingClearBytesSize = 32u;
const size_t kH264LeadingClearBytesSize = 32u; const size_t kSmallNalUnitSize = 32u + 16u;
const size_t kSmallNalUnitSize = 32u + 16u; const size_t kAudioLeadingClearBytesSize = 16u;
const size_t kAudioLeadingClearBytesSize = 16u; switch (codec_) {
switch (codec_) { case kCodecH264:
case kCodecH264: leading_clear_bytes_size_ = kH264LeadingClearBytesSize;
// Apple Sample AES uses 1:9 pattern for video. min_protected_data_size_ = kSmallNalUnitSize + 1u;
crypt_byte_block_ = 1u; break;
skip_byte_block_ = 9u; case kCodecAAC:
leading_clear_bytes_size_ = kH264LeadingClearBytesSize; FALLTHROUGH_INTENDED;
min_protected_data_size_ = kSmallNalUnitSize + 1u; case kCodecAC3:
break; FALLTHROUGH_INTENDED;
case kCodecAAC: case kCodecEAC3:
FALLTHROUGH_INTENDED; // E-AC3 encryption is handled by SampleAesEc3Cryptor, which also
case kCodecAC3: // manages leading clear bytes.
FALLTHROUGH_INTENDED; leading_clear_bytes_size_ =
case kCodecEAC3: codec_ == kCodecEAC3 ? 0 : kAudioLeadingClearBytesSize;
// Audio is whole sample encrypted. We could not use a min_protected_data_size_ = leading_clear_bytes_size_ + 15u;
// crypto_byte_block_ of 1 here as if there is one crypto block break;
// remaining, it need not be encrypted for video but it needs to be default:
// encrypted for audio. return Status(
crypt_byte_block_ = 0u; error::ENCRYPTION_FAILURE,
skip_byte_block_ = 0u; "Only AAC/AC3/EAC3 and H264 are supported in Sample AES.");
// E-AC3 encryption is handled by SampleAesEc3Cryptor, which also
// manages leading clear bytes.
leading_clear_bytes_size_ =
codec_ == kCodecEAC3 ? 0 : kAudioLeadingClearBytesSize;
min_protected_data_size_ = leading_clear_bytes_size_ + 15u;
break;
default:
return Status(
error::ENCRYPTION_FAILURE,
"Only AAC/AC3/EAC3 and H264 are supported in Sample AES.");
}
break;
} }
case FOURCC_cbcs: }
FALLTHROUGH_INTENDED; if (stream_type == kStreamVideo &&
case FOURCC_cens: IsPatternEncryptionScheme(protection_scheme_)) {
if (stream_type == kStreamVideo) { // Use 1:9 pattern.
// Use 1:9 pattern for video. crypt_byte_block_ = 1u;
crypt_byte_block_ = 1u; skip_byte_block_ = 9u;
skip_byte_block_ = 9u; } else {
} else { // Audio stream in pattern encryption scheme does not use pattern; it uses
// Tracks other than video are protected using whole-block full-sample // whole-block full sample encryption instead. Non-pattern encryption does
// encryption. Note that this may not be the same as the non-pattern // not have pattern.
// based encryption counterparts, e.g. in 'cens' whole-block full sample crypt_byte_block_ = 0u;
// encryption, the whole sample is encrypted up to the last 16-byte skip_byte_block_ = 0u;
// boundary, see 23001-7:2016(E) 9.7; while in 'cenc' full sample
// encryption, the last partial 16-byte block is also encrypted, see
// 23001-7:2016(E) 9.4.2. Another difference is the use of constant iv.
crypt_byte_block_ = 0u;
skip_byte_block_ = 0u;
}
break;
default:
// Not using pattern encryption.
crypt_byte_block_ = 0u;
skip_byte_block_ = 0u;
break;
} }
return Status::OK; return Status::OK;
} }
bool EncryptionHandler::CreateEncryptor(const EncryptionKey& encryption_key) { bool EncryptionHandler::CreateEncryptor(const EncryptionKey& encryption_key) {
std::unique_ptr<AesCryptor> encryptor; std::unique_ptr<AesCryptor> encryptor = encryptor_factory_->CreateEncryptor(
switch (protection_scheme_) { protection_scheme_, crypt_byte_block_, skip_byte_block_, codec_,
case FOURCC_cenc: encryption_key.key, encryption_key.iv);
encryptor.reset(new AesCtrEncryptor); if (!encryptor)
break; return false;
case FOURCC_cbc1:
encryptor.reset(new AesCbcEncryptor(kNoPadding));
break;
case FOURCC_cens:
encryptor.reset(new AesPatternCryptor(
crypt_byte_block_, skip_byte_block_,
AesPatternCryptor::kEncryptIfCryptByteBlockRemaining,
AesCryptor::kDontUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCtrEncryptor())));
break;
case FOURCC_cbcs:
encryptor.reset(new AesPatternCryptor(
crypt_byte_block_, skip_byte_block_,
AesPatternCryptor::kEncryptIfCryptByteBlockRemaining,
AesCryptor::kUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCbcEncryptor(kNoPadding))));
break;
case kAppleSampleAesProtectionScheme:
if (crypt_byte_block_ == 0 && skip_byte_block_ == 0) {
if (codec_ == kCodecEAC3) {
encryptor.reset(new SampleAesEc3Cryptor(
std::unique_ptr<AesCryptor>(new AesCbcEncryptor(kNoPadding))));
} else {
encryptor.reset(
new AesCbcEncryptor(kNoPadding, AesCryptor::kUseConstantIv));
}
} else {
encryptor.reset(new AesPatternCryptor(
crypt_byte_block_, skip_byte_block_,
AesPatternCryptor::kSkipIfCryptByteBlockRemaining,
AesCryptor::kUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCbcEncryptor(kNoPadding))));
}
break;
default:
LOG(ERROR) << "Unsupported protection scheme.";
return false;
}
std::vector<uint8_t> iv = encryption_key.iv;
if (iv.empty()) {
if (!AesCryptor::GenerateRandomIv(protection_scheme_, &iv)) {
LOG(ERROR) << "Failed to generate random iv.";
return false;
}
}
const bool initialized =
encryptor->InitializeWithIv(encryption_key.key, iv);
encryptor_ = std::move(encryptor); encryptor_ = std::move(encryptor);
encryption_config_.reset(new EncryptionConfig); encryption_config_.reset(new EncryptionConfig);
encryption_config_->protection_scheme = protection_scheme_; encryption_config_->protection_scheme = protection_scheme_;
encryption_config_->crypt_byte_block = crypt_byte_block_; encryption_config_->crypt_byte_block = crypt_byte_block_;
encryption_config_->skip_byte_block = skip_byte_block_; encryption_config_->skip_byte_block = skip_byte_block_;
const std::vector<uint8_t>& iv = encryptor_->iv();
if (encryptor_->use_constant_iv()) { if (encryptor_->use_constant_iv()) {
encryption_config_->per_sample_iv_size = 0; encryption_config_->per_sample_iv_size = 0;
encryption_config_->constant_iv = iv; encryption_config_->constant_iv = iv;
} else { } else {
encryption_config_->per_sample_iv_size = static_cast<uint8_t>(iv.size()); encryption_config_->per_sample_iv_size = static_cast<uint8_t>(iv.size());
} }
encryption_config_->key_id = encryption_key.key_id; encryption_config_->key_id = encryption_key.key_id;
encryption_config_->key_system_info = encryption_key.key_system_info; encryption_config_->key_system_info = encryption_key.key_system_info;
return initialized; return true;
} }
bool EncryptionHandler::EncryptVpxFrame( bool EncryptionHandler::EncryptVpxFrame(
@ -584,5 +521,10 @@ void EncryptionHandler::InjectVideoSliceHeaderParserForTesting(
header_parser_ = std::move(header_parser); header_parser_ = std::move(header_parser);
} }
void EncryptionHandler::InjectEncryptorFactoryForTesting(
std::unique_ptr<AesEncryptorFactory> encryptor_factory) {
encryptor_factory_ = std::move(encryptor_factory);
}
} // namespace media } // namespace media
} // namespace shaka } // namespace shaka

View File

@ -15,6 +15,7 @@ namespace shaka {
namespace media { namespace media {
class AesCryptor; class AesCryptor;
class AesEncryptorFactory;
class VideoSliceHeaderParser; class VideoSliceHeaderParser;
class VPxParser; class VPxParser;
struct EncryptionKey; struct EncryptionKey;
@ -80,6 +81,8 @@ class EncryptionHandler : public MediaHandler {
void InjectVpxParserForTesting(std::unique_ptr<VPxParser> vpx_parser); void InjectVpxParserForTesting(std::unique_ptr<VPxParser> vpx_parser);
void InjectVideoSliceHeaderParserForTesting( void InjectVideoSliceHeaderParserForTesting(
std::unique_ptr<VideoSliceHeaderParser> header_parser); std::unique_ptr<VideoSliceHeaderParser> header_parser);
void InjectEncryptorFactoryForTesting(
std::unique_ptr<AesEncryptorFactory> encryptor_factory);
const EncryptionParams encryption_params_; const EncryptionParams encryption_params_;
const FourCC protection_scheme_ = FOURCC_NULL; const FourCC protection_scheme_ = FOURCC_NULL;
@ -111,6 +114,7 @@ class EncryptionHandler : public MediaHandler {
/// Number of unencrypted blocks (16-byte-block) in pattern based encryption. /// Number of unencrypted blocks (16-byte-block) in pattern based encryption.
uint8_t skip_byte_block_ = 0; uint8_t skip_byte_block_ = 0;
std::unique_ptr<AesEncryptorFactory> encryptor_factory_;
// VPx parser for VPx streams. // VPx parser for VPx streams.
std::unique_ptr<VPxParser> vpx_parser_; std::unique_ptr<VPxParser> vpx_parser_;
// Video slice header parser for NAL strucutred streams. // Video slice header parser for NAL strucutred streams.

View File

@ -9,12 +9,13 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "packager/media/base/aes_decryptor.h" #include "packager/media/base/aes_cryptor.h"
#include "packager/media/base/aes_pattern_cryptor.h"
#include "packager/media/base/media_handler_test_base.h" #include "packager/media/base/media_handler_test_base.h"
#include "packager/media/base/mock_aes_cryptor.h"
#include "packager/media/base/raw_key_source.h" #include "packager/media/base/raw_key_source.h"
#include "packager/media/codecs/video_slice_header_parser.h" #include "packager/media/codecs/video_slice_header_parser.h"
#include "packager/media/codecs/vpx_parser.h" #include "packager/media/codecs/vpx_parser.h"
#include "packager/media/crypto/aes_encryptor_factory.h"
#include "packager/status_test_util.h" #include "packager/status_test_util.h"
namespace shaka { namespace shaka {
@ -22,9 +23,11 @@ namespace media {
namespace { namespace {
using ::testing::_; using ::testing::_;
using ::testing::ByMove;
using ::testing::Combine; using ::testing::Combine;
using ::testing::DoAll; using ::testing::DoAll;
using ::testing::ElementsAre; using ::testing::ElementsAre;
using ::testing::Invoke;
using ::testing::Mock; using ::testing::Mock;
using ::testing::Return; using ::testing::Return;
using ::testing::SetArgPointee; using ::testing::SetArgPointee;
@ -33,6 +36,8 @@ using ::testing::Values;
using ::testing::ValuesIn; using ::testing::ValuesIn;
using ::testing::WithParamInterface; using ::testing::WithParamInterface;
const size_t kStreamIndex = 0;
const uint32_t kTimeScale = 1000;
const char kAudioStreamLabel[] = "AUDIO"; const char kAudioStreamLabel[] = "AUDIO";
const char kSdVideoStreamLabel[] = "SD"; const char kSdVideoStreamLabel[] = "SD";
@ -79,6 +84,17 @@ class MockVideoSliceHeaderParser : public VideoSliceHeaderParser {
MOCK_METHOD1(GetHeaderSize, int64_t(const Nalu& nalu)); MOCK_METHOD1(GetHeaderSize, int64_t(const Nalu& nalu));
}; };
class MockAesEncryptorFactory : public AesEncryptorFactory {
public:
MOCK_METHOD6(CreateEncryptor,
std::unique_ptr<AesCryptor>(FourCC protection_scheme,
uint8_t crypt_byte_block,
uint8_t skip_byte_block,
Codec codec,
const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv));
};
} // namespace } // namespace
class EncryptionHandlerTest : public MediaHandlerGraphTestBase { class EncryptionHandlerTest : public MediaHandlerGraphTestBase {
@ -126,6 +142,12 @@ class EncryptionHandlerTest : public MediaHandlerGraphTestBase {
std::move(header_parser)); std::move(header_parser));
} }
void InjectEncryptorFactoryForTesting(
std::unique_ptr<AesEncryptorFactory> encryptor_factory) {
encryption_handler_->InjectEncryptorFactoryForTesting(
std::move(encryptor_factory));
}
protected: protected:
std::shared_ptr<EncryptionHandler> encryption_handler_; std::shared_ptr<EncryptionHandler> encryption_handler_;
StrictMock<MockKeySource> mock_key_source_; StrictMock<MockKeySource> mock_key_source_;
@ -148,15 +170,39 @@ TEST_F(EncryptionHandlerTest, OnlyOneInput) {
encryption_handler_->Initialize().error_code()); encryption_handler_->Initialize().error_code());
} }
TEST_F(EncryptionHandlerTest, GetKeyFailed) {
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(Return(Status(error::INVALID_ARGUMENT, "")));
ASSERT_NOT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetVideoStreamInfo(kTimeScale, kCodecH264))));
}
TEST_F(EncryptionHandlerTest, CreateEncryptorFailed) {
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(
DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK)));
std::unique_ptr<MockAesEncryptorFactory> mock_encryptor_factory(
new MockAesEncryptorFactory);
EXPECT_CALL(*mock_encryptor_factory,
CreateEncryptor(_, _, _, _, mock_encryption_key.key,
mock_encryption_key.iv))
.WillOnce(Return(ByMove(nullptr)));
InjectEncryptorFactoryForTesting(std::move(mock_encryptor_factory));
ASSERT_NOT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetVideoStreamInfo(kTimeScale, kCodecH264))));
}
namespace { namespace {
const bool kVp9SubsampleEncryption = true; const bool kVp9SubsampleEncryption = true;
const bool kIsKeyFrame = true; const bool kIsKeyFrame = true;
const bool kIsSubsegment = true; const bool kIsSubsegment = true;
const bool kEncrypted = true; const bool kEncrypted = true;
const size_t kStreamIndex = 0;
const uint32_t kTimeScale = 1000;
const int64_t kSampleDuration = 1000;
const int64_t kSegmentDuration = 1000; const int64_t kSegmentDuration = 1000;
// The data is based on H264. The same data is also used to test audio, which // The data is based on H264. The same data is also used to test audio, which
@ -179,9 +225,6 @@ const uint8_t kData[]{
0x06, 0x67, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x06, 0x67, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
}; };
const size_t kDataSize = sizeof(kData); const size_t kDataSize = sizeof(kData);
// A short data size (less than leading clear bytes) for SampleAes audio
// testing.
const size_t kShortDataSize = 14;
// H264 subsample information for the the above data. // H264 subsample information for the the above data.
const size_t kNaluLengthSize = 1u; const size_t kNaluLengthSize = 1u;
@ -306,106 +349,14 @@ class EncryptionHandlerEncryptionTest
} }
} }
bool Decrypt(const DecryptConfig& decrypt_config,
uint8_t* data,
size_t data_size) {
size_t leading_clear_bytes_size = 0;
std::unique_ptr<AesCryptor> aes_decryptor;
switch (decrypt_config.protection_scheme()) {
case FOURCC_cenc:
aes_decryptor.reset(new AesCtrDecryptor);
break;
case FOURCC_cbc1:
aes_decryptor.reset(new AesCbcDecryptor(kNoPadding));
break;
case FOURCC_cens:
aes_decryptor.reset(new AesPatternCryptor(
decrypt_config.crypt_byte_block(), decrypt_config.skip_byte_block(),
AesPatternCryptor::kEncryptIfCryptByteBlockRemaining,
AesCryptor::kDontUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCtrDecryptor())));
break;
case FOURCC_cbcs:
aes_decryptor.reset(new AesPatternCryptor(
decrypt_config.crypt_byte_block(), decrypt_config.skip_byte_block(),
AesPatternCryptor::kEncryptIfCryptByteBlockRemaining,
AesCryptor::kUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCbcDecryptor(kNoPadding))));
break;
case kAppleSampleAesProtectionScheme:
if (decrypt_config.crypt_byte_block() == 0 &&
decrypt_config.skip_byte_block() == 0) {
const size_t kAudioLeadingClearBytesSize = 16u;
// Only needed for audio; for video, it is already taken into
// consideration in subsamples.
leading_clear_bytes_size = kAudioLeadingClearBytesSize;
aes_decryptor.reset(
new AesCbcDecryptor(kNoPadding, AesCryptor::kUseConstantIv));
} else {
aes_decryptor.reset(new AesPatternCryptor(
decrypt_config.crypt_byte_block(),
decrypt_config.skip_byte_block(),
AesPatternCryptor::kSkipIfCryptByteBlockRemaining,
AesCryptor::kUseConstantIv,
std::unique_ptr<AesCryptor>(new AesCbcDecryptor(kNoPadding))));
}
break;
default:
LOG(FATAL) << "Not supposed to happen.";
}
if (!aes_decryptor->InitializeWithIv(
std::vector<uint8_t>(kKey, kKey + sizeof(kKey)),
decrypt_config.iv())) {
return false;
}
if (decrypt_config.subsamples().empty()) {
// Sample not encrypted using subsample encryption. Decrypt whole.
if (!aes_decryptor->Crypt(data + leading_clear_bytes_size,
data_size - leading_clear_bytes_size,
data + leading_clear_bytes_size)) {
LOG(ERROR) << "Error during bulk sample decryption.";
return false;
}
return true;
}
// Subsample decryption.
const std::vector<SubsampleEntry>& subsamples = decrypt_config.subsamples();
uint8_t* current_ptr = data;
const uint8_t* const buffer_end = data + data_size;
for (const auto& subsample : subsamples) {
if (current_ptr + subsample.clear_bytes + subsample.cipher_bytes >
buffer_end) {
LOG(ERROR) << "Subsamples overflow sample buffer.";
return false;
}
current_ptr += subsample.clear_bytes;
if (!aes_decryptor->Crypt(current_ptr, subsample.cipher_bytes,
current_ptr)) {
LOG(ERROR) << "Error decrypting subsample buffer.";
return false;
}
current_ptr += subsample.cipher_bytes;
}
return true;
}
uint8_t GetExpectedCryptByteBlock() { uint8_t GetExpectedCryptByteBlock() {
if (protection_scheme_ == kAppleSampleAesProtectionScheme) {
// Audio is whole sample encrypted. We could not use a
// crypto_byte_block_ of 1 for audio as if there is one crypto block
// remaining, it need not be encrypted for video but it needs to be
// encrypted for audio.
return codec_ == kCodecAAC ? 0 : 1;
}
switch (protection_scheme_) { switch (protection_scheme_) {
case FOURCC_cenc: case FOURCC_cenc:
case FOURCC_cbc1: case FOURCC_cbc1:
return 0; return 0;
case FOURCC_cens: case FOURCC_cens:
case FOURCC_cbcs: case FOURCC_cbcs:
case kAppleSampleAesProtectionScheme:
return codec_ == kCodecAAC ? 0 : 1; return codec_ == kCodecAAC ? 0 : 1;
default: default:
return 0; return 0;
@ -459,6 +410,35 @@ class EncryptionHandlerEncryptionTest
bool vp9_subsample_encryption_; bool vp9_subsample_encryption_;
}; };
TEST_P(EncryptionHandlerEncryptionTest, VerifyEncryptorFactoryParams) {
EncryptionParams encryption_params;
encryption_params.protection_scheme = protection_scheme_;
SetUpEncryptionHandler(encryption_params);
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(
DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK)));
std::unique_ptr<MockAesCryptor> mock_encryptor(new MockAesCryptor);
std::unique_ptr<MockAesEncryptorFactory> mock_encryptor_factory(
new MockAesEncryptorFactory);
EXPECT_CALL(*mock_encryptor_factory,
CreateEncryptor(protection_scheme_, GetExpectedCryptByteBlock(),
GetExpectedSkipByteBlock(), codec_,
mock_encryption_key.key, mock_encryption_key.iv))
.WillOnce(Return(ByMove(std::move(mock_encryptor))));
InjectEncryptorFactoryForTesting(std::move(mock_encryptor_factory));
if (IsVideoCodec(codec_)) {
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetVideoStreamInfo(kTimeScale, codec_))));
} else {
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetAudioStreamInfo(kTimeScale, codec_))));
}
}
TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithNoKeyRotation) { TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithNoKeyRotation) {
const double kClearLeadInSeconds = 1.5 * kSegmentDuration / kTimeScale; const double kClearLeadInSeconds = 1.5 * kSegmentDuration / kTimeScale;
EncryptionParams encryption_params; EncryptionParams encryption_params;
@ -514,6 +494,19 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithNoKeyRotation) {
IsSegmentInfo(kStreamIndex, i * kSegmentDuration, IsSegmentInfo(kStreamIndex, i * kSegmentDuration,
kSegmentDuration, !kIsSubsegment, kSegmentDuration, !kIsSubsegment,
is_encrypted))); is_encrypted)));
if (is_encrypted) {
const auto* media_sample = output_stream_data.front()->media_sample.get();
const auto* decrypt_config = media_sample->decrypt_config();
EXPECT_EQ(std::vector<uint8_t>(kKeyId, kKeyId + sizeof(kKeyId)),
decrypt_config->key_id());
EXPECT_EQ(std::vector<uint8_t>(kIv, kIv + sizeof(kIv)),
decrypt_config->iv());
EXPECT_EQ(GetExpectedSubsamples(), decrypt_config->subsamples());
EXPECT_EQ(protection_scheme_, decrypt_config->protection_scheme());
EXPECT_EQ(GetExpectedCryptByteBlock(),
decrypt_config->crypt_byte_block());
EXPECT_EQ(GetExpectedSkipByteBlock(), decrypt_config->skip_byte_block());
}
EXPECT_FALSE(output_stream_data.back() EXPECT_FALSE(output_stream_data.back()
->segment_info->key_rotation_encryption_config); ->segment_info->key_rotation_encryption_config);
ClearOutputStreamDataVector(); ClearOutputStreamDataVector();
@ -593,99 +586,6 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithKeyRotation) {
} }
} }
TEST_P(EncryptionHandlerEncryptionTest, Encrypt) {
EncryptionParams encryption_params;
encryption_params.protection_scheme = protection_scheme_;
encryption_params.vp9_subsample_encryption = vp9_subsample_encryption_;
SetUpEncryptionHandler(encryption_params);
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(
DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK)));
if (IsVideoCodec(codec_)) {
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetVideoStreamInfo(kTimeScale, codec_))));
} else {
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetAudioStreamInfo(kTimeScale, codec_))));
}
EXPECT_THAT(
GetOutputStreamDataVector(),
ElementsAre(IsStreamInfo(kStreamIndex, kTimeScale, kEncrypted, _)));
const StreamInfo* stream_info =
GetOutputStreamDataVector().back()->stream_info.get();
ASSERT_TRUE(stream_info);
EXPECT_FALSE(stream_info->has_clear_lead());
InjectCodecParser();
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kDataSize))));
ASSERT_EQ(2u, GetOutputStreamDataVector().size());
ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index);
ASSERT_EQ(StreamDataType::kMediaSample,
GetOutputStreamDataVector().back()->stream_data_type);
auto* media_sample = GetOutputStreamDataVector().back()->media_sample.get();
auto* decrypt_config = media_sample->decrypt_config();
EXPECT_EQ(std::vector<uint8_t>(kKeyId, kKeyId + sizeof(kKeyId)),
decrypt_config->key_id());
EXPECT_EQ(std::vector<uint8_t>(kIv, kIv + sizeof(kIv)), decrypt_config->iv());
EXPECT_EQ(GetExpectedSubsamples(), decrypt_config->subsamples());
EXPECT_EQ(protection_scheme_, decrypt_config->protection_scheme());
EXPECT_EQ(GetExpectedCryptByteBlock(), decrypt_config->crypt_byte_block());
EXPECT_EQ(GetExpectedSkipByteBlock(), decrypt_config->skip_byte_block());
std::vector<uint8_t> expected(kData, kData + kDataSize);
std::vector<uint8_t> actual(media_sample->data(),
media_sample->data() + media_sample->data_size());
ASSERT_TRUE(Decrypt(*decrypt_config, actual.data(), actual.size()));
EXPECT_EQ(expected, actual);
}
// Verify that the data in short audio (less than leading clear bytes) is left
// unencrypted.
TEST_P(EncryptionHandlerEncryptionTest, SampleAesEncryptShortAudio) {
if (IsVideoCodec(codec_) ||
protection_scheme_ != kAppleSampleAesProtectionScheme) {
return;
}
EncryptionParams encryption_params;
encryption_params.protection_scheme = kAppleSampleAesProtectionScheme;
SetUpEncryptionHandler(encryption_params);
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(
DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK)));
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetAudioStreamInfo(kTimeScale, codec_))));
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kShortDataSize))));
ASSERT_EQ(2u, GetOutputStreamDataVector().size());
ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index);
ASSERT_EQ(StreamDataType::kMediaSample,
GetOutputStreamDataVector().back()->stream_data_type);
auto* media_sample = GetOutputStreamDataVector().back()->media_sample.get();
auto* decrypt_config = media_sample->decrypt_config();
EXPECT_TRUE(decrypt_config->subsamples().empty());
EXPECT_EQ(kAppleSampleAesProtectionScheme,
decrypt_config->protection_scheme());
std::vector<uint8_t> expected(kData, kData + kShortDataSize);
std::vector<uint8_t> actual(media_sample->data(),
media_sample->data() + media_sample->data_size());
EXPECT_EQ(expected, actual);
}
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
CencProtectionSchemes, CencProtectionSchemes,
EncryptionHandlerEncryptionTest, EncryptionHandlerEncryptionTest,
@ -698,10 +598,7 @@ INSTANTIATE_TEST_CASE_P(AppleSampleAes,
Values(kCodecAAC, kCodecH264), Values(kCodecAAC, kCodecH264),
Values(kVp9SubsampleEncryption))); Values(kVp9SubsampleEncryption)));
class EncryptionHandlerTrackTypeTest : public EncryptionHandlerTest { class EncryptionHandlerTrackTypeTest : public EncryptionHandlerTest {};
public:
void SetUp() override {}
};
TEST_F(EncryptionHandlerTrackTypeTest, AudioTrackType) { TEST_F(EncryptionHandlerTrackTypeTest, AudioTrackType) {
EncryptionParams::EncryptedStreamAttributes captured_stream_attributes; EncryptionParams::EncryptedStreamAttributes captured_stream_attributes;

View File

@ -9,6 +9,8 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "packager/media/base/mock_aes_cryptor.h"
using ::testing::_; using ::testing::_;
using ::testing::Invoke; using ::testing::Invoke;
using ::testing::Return; using ::testing::Return;
@ -16,21 +18,6 @@ using ::testing::Return;
namespace shaka { namespace shaka {
namespace media { namespace media {
class MockAesCryptor : public AesCryptor {
public:
MockAesCryptor() : AesCryptor(kDontUseConstantIv) {}
MOCK_METHOD2(InitializeWithIv,
bool(const std::vector<uint8_t>& key,
const std::vector<uint8_t>& iv));
MOCK_METHOD4(CryptInternal,
bool(const uint8_t* text,
size_t text_size,
uint8_t* crypt_text,
size_t* crypt_text_size));
MOCK_METHOD0(SetIvInternal, void());
};
class SampleAesEc3CryptorTest : public ::testing::Test { class SampleAesEc3CryptorTest : public ::testing::Test {
public: public:
SampleAesEc3CryptorTest() SampleAesEc3CryptorTest()