diff --git a/packager/media/base/decryptor_source.cc b/packager/media/base/decryptor_source.cc index fb18f2cfde..ef1a9235fe 100644 --- a/packager/media/base/decryptor_source.cc +++ b/packager/media/base/decryptor_source.cc @@ -10,6 +10,18 @@ #include "packager/media/base/aes_decryptor.h" #include "packager/media/base/aes_pattern_cryptor.h" +namespace { +// Return true if [encrypted_buffer, encrypted_buffer + buffer_size) overlaps +// with [decrypted_buffer, decrypted_buffer + buffer_size). +bool CheckMemoryOverlap(const uint8_t* encrypted_buffer, + size_t buffer_size, + uint8_t* decrypted_buffer) { + return (decrypted_buffer < encrypted_buffer) + ? (encrypted_buffer < decrypted_buffer + buffer_size) + : (decrypted_buffer < encrypted_buffer + buffer_size); +} +} // namespace + namespace shaka { namespace media { @@ -20,10 +32,17 @@ DecryptorSource::DecryptorSource(KeySource* key_source) DecryptorSource::~DecryptorSource() {} bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config, - uint8_t* buffer, - size_t buffer_size) { + const uint8_t* encrypted_buffer, + size_t buffer_size, + uint8_t* decrypted_buffer) { DCHECK(decrypt_config); - DCHECK(buffer); + DCHECK(encrypted_buffer); + DCHECK(decrypted_buffer); + + if (CheckMemoryOverlap(encrypted_buffer, buffer_size, decrypted_buffer)) { + LOG(ERROR) << "Encrypted buffer and decrypted buffer cannot overlap."; + return false; + } // Get the decryptor object. AesCryptor* decryptor = nullptr; @@ -83,7 +102,7 @@ bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config, if (decrypt_config->subsamples().empty()) { // Sample not encrypted using subsample encryption. Decrypt whole. - if (!decryptor->Crypt(buffer, buffer_size, buffer)) { + if (!decryptor->Crypt(encrypted_buffer, buffer_size, decrypted_buffer)) { LOG(ERROR) << "Error during bulk sample decryption."; return false; } @@ -92,20 +111,24 @@ bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config, // Subsample decryption. const std::vector& subsamples = decrypt_config->subsamples(); - uint8_t* current_ptr = buffer; - const uint8_t* const buffer_end = buffer + buffer_size; + const uint8_t* current_ptr = encrypted_buffer; + const uint8_t* const buffer_end = encrypted_buffer + buffer_size; for (const auto& subsample : subsamples) { if ((current_ptr + subsample.clear_bytes + subsample.cipher_bytes) > buffer_end) { LOG(ERROR) << "Subsamples overflow sample buffer."; return false; } + memcpy(decrypted_buffer, current_ptr, subsample.clear_bytes); current_ptr += subsample.clear_bytes; - if (!decryptor->Crypt(current_ptr, subsample.cipher_bytes, current_ptr)) { + decrypted_buffer += subsample.clear_bytes; + if (!decryptor->Crypt(current_ptr, subsample.cipher_bytes, + decrypted_buffer)) { LOG(ERROR) << "Error decrypting subsample buffer."; return false; } current_ptr += subsample.cipher_bytes; + decrypted_buffer += subsample.cipher_bytes; } return true; } diff --git a/packager/media/base/decryptor_source.h b/packager/media/base/decryptor_source.h index a4ac9dafbf..6d7f8e962b 100644 --- a/packager/media/base/decryptor_source.h +++ b/packager/media/base/decryptor_source.h @@ -21,12 +21,24 @@ namespace media { /// DecryptorSource wraps KeySource and is responsible for decryptor management. class DecryptorSource { public: + /// Constructs a DecryptorSource object. + /// @param key_source points to the key source that contains the keys. explicit DecryptorSource(KeySource* key_source); ~DecryptorSource(); + /// Decrypt encrypted buffer. + /// @param decrypt_config contains decrypt configuration, e.g. protection + /// scheme, subsample information etc. + /// @param encrypted_buffer points to the encrypted buffer that is to be + /// decrypted. It should not overlap with @a decrypted_buffer. + /// @param buffer_size is the size of encrypted buffer and decrypted buffer. + /// @param decrypted_buffer points to the decrypted buffer. It should not + /// overlap with @a encrypted_buffer. + /// @return true if success, false otherwise. bool DecryptSampleBuffer(const DecryptConfig* decrypt_config, - uint8_t* buffer, - size_t buffer_size); + const uint8_t* encrypted_buffer, + size_t buffer_size, + uint8_t* decrypted_buffer); private: KeySource* key_source_; diff --git a/packager/media/base/decryptor_source_unittest.cc b/packager/media/base/decryptor_source_unittest.cc index e0eba2d55b..da242164a2 100644 --- a/packager/media/base/decryptor_source_unittest.cc +++ b/packager/media/base/decryptor_source_unittest.cc @@ -64,13 +64,15 @@ class DecryptorSourceTest : public ::testing::Test { DecryptorSourceTest() : decryptor_source_(&mock_key_source_), key_id_(std::vector(kKeyId, kKeyId + arraysize(kKeyId))), - buffer_(std::vector(kBuffer, kBuffer + arraysize(kBuffer))) {} + encrypted_buffer_(kBuffer, kBuffer + arraysize(kBuffer)), + decrypted_buffer_(arraysize(kBuffer)) {} protected: StrictMock mock_key_source_; DecryptorSource decryptor_source_; std::vector key_id_; - std::vector buffer_; + std::vector encrypted_buffer_; + std::vector decrypted_buffer_; }; TEST_F(DecryptorSourceTest, FullSampleDecryption) { @@ -83,24 +85,27 @@ TEST_F(DecryptorSourceTest, FullSampleDecryption) { std::vector(kIv, kIv + arraysize(kIv)), std::vector()); ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer( - &decrypt_config, &buffer_[0], buffer_.size())); + &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(), + &decrypted_buffer_[0])); EXPECT_EQ(std::vector( kExpectedDecryptedBuffer, kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)), - buffer_); + decrypted_buffer_); // DecryptSampleBuffer can be called repetitively. No GetKey call again with // the same key id. - buffer_.assign(kBuffer2, kBuffer2 + arraysize(kBuffer2)); + encrypted_buffer_.assign(kBuffer2, kBuffer2 + arraysize(kBuffer2)); + decrypted_buffer_.resize(arraysize(kBuffer2)); DecryptConfig decrypt_config2( key_id_, std::vector(kIv2, kIv2 + arraysize(kIv2)), std::vector()); ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer( - &decrypt_config2, &buffer_[0], buffer_.size())); + &decrypt_config2, &encrypted_buffer_[0], encrypted_buffer_.size(), + &decrypted_buffer_[0])); EXPECT_EQ(std::vector(kExpectedDecryptedBuffer2, kExpectedDecryptedBuffer2 + arraysize(kExpectedDecryptedBuffer2)), - buffer_); + decrypted_buffer_); } TEST_F(DecryptorSourceTest, SubsampleDecryption) { @@ -131,11 +136,12 @@ TEST_F(DecryptorSourceTest, SubsampleDecryption) { std::vector(kSubsamples, kSubsamples + arraysize(kSubsamples))); ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer( - &decrypt_config, &buffer_[0], buffer_.size())); + &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(), + &decrypted_buffer_[0])); EXPECT_EQ(std::vector( kExpectedDecryptedBuffer, kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)), - buffer_); + decrypted_buffer_); } TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) { @@ -155,7 +161,8 @@ TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) { std::vector(kSubsamples, kSubsamples + arraysize(kSubsamples))); ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer( - &decrypt_config, &buffer_[0], buffer_.size())); + &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(), + &decrypted_buffer_[0])); } TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) { @@ -166,7 +173,17 @@ TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) { std::vector(kIv, kIv + arraysize(kIv)), std::vector()); ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer( - &decrypt_config, &buffer_[0], buffer_.size())); + &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(), + &decrypted_buffer_[0])); +} + +TEST_F(DecryptorSourceTest, EncryptedBufferAndDecryptedBufferOverlap) { + DecryptConfig decrypt_config(key_id_, + std::vector(kIv, kIv + arraysize(kIv)), + std::vector()); + ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer( + &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(), + &encrypted_buffer_[5])); } } // namespace media diff --git a/packager/media/base/media_handler_test_base.cc b/packager/media/base/media_handler_test_base.cc index c2ff0aec50..c54c8246c8 100644 --- a/packager/media/base/media_handler_test_base.cc +++ b/packager/media/base/media_handler_test_base.cc @@ -154,21 +154,21 @@ std::unique_ptr MediaHandlerTestBase::GetAudioStreamInfo( !kEncrypted)); } -std::unique_ptr MediaHandlerTestBase::GetMediaSample( +std::shared_ptr MediaHandlerTestBase::GetMediaSample( int64_t timestamp, int64_t duration, bool is_keyframe) const { return GetMediaSample(timestamp, duration, is_keyframe, kData, sizeof(kData)); } -std::unique_ptr MediaHandlerTestBase::GetMediaSample( +std::shared_ptr MediaHandlerTestBase::GetMediaSample( int64_t timestamp, int64_t duration, bool is_keyframe, const uint8_t* data, size_t data_length) const { - std::unique_ptr sample( - new MediaSample(data, data_length, nullptr, 0, is_keyframe)); + std::shared_ptr sample = + MediaSample::CopyFrom(data, data_length, nullptr, 0, is_keyframe); sample->set_dts(timestamp); sample->set_duration(duration); diff --git a/packager/media/base/media_handler_test_base.h b/packager/media/base/media_handler_test_base.h index 5a21367d9d..098bfcae97 100644 --- a/packager/media/base/media_handler_test_base.h +++ b/packager/media/base/media_handler_test_base.h @@ -150,11 +150,11 @@ class MediaHandlerTestBase : public ::testing::Test { std::unique_ptr GetAudioStreamInfo(uint32_t time_scale, Codec codec) const; - std::unique_ptr GetMediaSample(int64_t timestamp, + std::shared_ptr GetMediaSample(int64_t timestamp, int64_t duration, bool is_keyframe) const; - std::unique_ptr GetMediaSample(int64_t timestamp, + std::shared_ptr GetMediaSample(int64_t timestamp, int64_t duration, bool is_keyframe, const uint8_t* data, diff --git a/packager/media/base/media_sample.cc b/packager/media/base/media_sample.cc index f055705ee7..d5b8df623b 100644 --- a/packager/media/base/media_sample.cc +++ b/packager/media/base/media_sample.cc @@ -15,29 +15,26 @@ namespace shaka { namespace media { MediaSample::MediaSample(const uint8_t* data, - size_t size, + size_t data_size, const uint8_t* side_data, size_t side_data_size, bool is_key_frame) - : dts_(0), - pts_(0), - duration_(0), - is_key_frame_(is_key_frame), - is_encrypted_(false) { + : is_key_frame_(is_key_frame) { if (!data) { - CHECK_EQ(size, 0u); + CHECK_EQ(data_size, 0u); } - data_.assign(data, data + size); - if (side_data) - side_data_.assign(side_data, side_data + side_data_size); + SetData(data, data_size); + if (side_data) { + std::shared_ptr shared_side_data(new uint8_t[side_data_size], + std::default_delete()); + memcpy(shared_side_data.get(), side_data, side_data_size); + side_data_ = std::move(shared_side_data); + side_data_size_ = side_data_size; + } } -MediaSample::MediaSample() : dts_(0), - pts_(0), - duration_(0), - is_key_frame_(false), - is_encrypted_(false) {} +MediaSample::MediaSample() {} MediaSample::~MediaSample() {} @@ -47,8 +44,8 @@ std::shared_ptr MediaSample::CopyFrom(const uint8_t* data, bool is_key_frame) { // If you hit this CHECK you likely have a bug in a demuxer. Go fix it. CHECK(data); - return std::make_shared(data, data_size, nullptr, 0u, - is_key_frame); + return std::shared_ptr( + new MediaSample(data, data_size, nullptr, 0u, is_key_frame)); } // static @@ -59,32 +56,32 @@ std::shared_ptr MediaSample::CopyFrom(const uint8_t* data, bool is_key_frame) { // If you hit this CHECK you likely have a bug in a demuxer. Go fix it. CHECK(data); - return std::make_shared(data, data_size, side_data, - side_data_size, is_key_frame); + return std::shared_ptr(new MediaSample( + data, data_size, side_data, side_data_size, is_key_frame)); } // static std::shared_ptr MediaSample::CopyFrom( const MediaSample& media_sample) { - std::shared_ptr new_media_sample = CopyFrom( - media_sample.data(), media_sample.data_size(), media_sample.side_data(), - media_sample.side_data_size(), media_sample.is_key_frame()); - - new_media_sample->set_dts(media_sample.dts()); - new_media_sample->set_pts(media_sample.pts()); - new_media_sample->set_is_encrypted(media_sample.is_encrypted()); - new_media_sample->set_config_id(media_sample.config_id()); - new_media_sample->set_duration(media_sample.duration()); - - if (media_sample.decrypt_config()) { - std::unique_ptr decrypt_config( - new DecryptConfig(media_sample.decrypt_config()->key_id(), - media_sample.decrypt_config()->iv(), - media_sample.decrypt_config()->subsamples(), - media_sample.decrypt_config()->protection_scheme(), - media_sample.decrypt_config()->crypt_byte_block(), - media_sample.decrypt_config()->skip_byte_block())); - new_media_sample->set_decrypt_config(std::move(decrypt_config)); + std::shared_ptr new_media_sample(new MediaSample); + new_media_sample->dts_ = media_sample.dts_; + new_media_sample->pts_ = media_sample.pts_; + new_media_sample->duration_ = media_sample.duration_; + new_media_sample->is_key_frame_ = media_sample.is_key_frame_; + new_media_sample->is_encrypted_ = media_sample.is_encrypted_; + new_media_sample->data_ = media_sample.data_; + new_media_sample->data_size_ = media_sample.data_size_; + new_media_sample->side_data_ = media_sample.side_data_; + new_media_sample->side_data_size_ = media_sample.side_data_size_; + new_media_sample->config_id_ = media_sample.config_id_; + if (media_sample.decrypt_config_) { + new_media_sample->decrypt_config_.reset( + new DecryptConfig(media_sample.decrypt_config_->key_id(), + media_sample.decrypt_config_->iv(), + media_sample.decrypt_config_->subsamples(), + media_sample.decrypt_config_->protection_scheme(), + media_sample.decrypt_config_->crypt_byte_block(), + media_sample.decrypt_config_->skip_byte_block())); } return new_media_sample; } @@ -92,32 +89,43 @@ std::shared_ptr MediaSample::CopyFrom( // static std::shared_ptr MediaSample::FromMetadata(const uint8_t* metadata, size_t metadata_size) { - return std::make_shared(nullptr, 0, metadata, metadata_size, - false); + return std::shared_ptr( + new MediaSample(nullptr, 0, metadata, metadata_size, false)); } // static std::shared_ptr MediaSample::CreateEmptyMediaSample() { - return std::make_shared(); + return std::shared_ptr(new MediaSample); } // static std::shared_ptr MediaSample::CreateEOSBuffer() { - return std::make_shared(nullptr, 0, nullptr, 0, false); + return std::shared_ptr( + new MediaSample(nullptr, 0, nullptr, 0, false)); +} + +void MediaSample::TransferData(std::shared_ptr data, + size_t data_size) { + data_ = std::move(data); + data_size_ = data_size; +} + +void MediaSample::SetData(const uint8_t* data, size_t data_size) { + std::shared_ptr shared_data(new uint8_t[data_size], + std::default_delete()); + memcpy(shared_data.get(), data, data_size); + TransferData(std::move(shared_data), data_size); } std::string MediaSample::ToString() const { if (end_of_stream()) return "End of stream sample\n"; return base::StringPrintf( - "dts: %" PRId64 "\n pts: %" PRId64 "\n duration: %" PRId64 "\n " + "dts: %" PRId64 "\n pts: %" PRId64 "\n duration: %" PRId64 + "\n " "is_key_frame: %s\n size: %zu\n side_data_size: %zu\n", - dts_, - pts_, - duration_, - is_key_frame_ ? "true" : "false", - data_.size(), - side_data_.size()); + dts_, pts_, duration_, is_key_frame_ ? "true" : "false", data_size_, + side_data_size_); } } // namespace media diff --git a/packager/media/base/media_sample.h b/packager/media/base/media_sample.h index 601cab674a..3a8a30bfd0 100644 --- a/packager/media/base/media_sample.h +++ b/packager/media/base/media_sample.h @@ -66,17 +66,22 @@ class MediaSample { /// is disallowed. static std::shared_ptr CreateEOSBuffer(); - // Create a MediaSample. Buffer will be padded and aligned as necessary. - // |data|,|side_data| can be NULL, which indicates an empty sample. - // |size|,|side_data_size| should not be negative. - MediaSample(const uint8_t* data, - size_t size, - const uint8_t* side_data, - size_t side_data_size, - bool is_key_frame); - MediaSample(); virtual ~MediaSample(); + /// Transfer data to this media sample. No data copying is involved. + /// @param data points to the data to be transferred. + /// @param data_size is the size of the data to be transferred. + void TransferData(std::shared_ptr data, size_t data_size); + + /// Set the data in this media sample. Note that this method involves data + /// copying. + /// @param data points to the data to be copied. + /// @param data_size is the size of the data to be copied. + void SetData(const uint8_t* data, size_t data_size); + + /// @return a human-readable string describing |*this|. + std::string ToString() const; + int64_t dts() const { DCHECK(!end_of_stream()); return dts_; @@ -112,38 +117,19 @@ class MediaSample { } const uint8_t* data() const { DCHECK(!end_of_stream()); - return data_.data(); - } - - uint8_t* writable_data() { - DCHECK(!end_of_stream()); - return data_.data(); + return data_.get(); } size_t data_size() const { DCHECK(!end_of_stream()); - return data_.size(); + return data_size_; } - const uint8_t* side_data() const { - return side_data_.data(); - } + const uint8_t* side_data() const { return side_data_.get(); } - size_t side_data_size() const { - return side_data_.size(); - } + size_t side_data_size() const { return side_data_size_; } - const DecryptConfig* decrypt_config() const { - return decrypt_config_.get(); - } - - void set_data(const uint8_t* data, const size_t data_size) { - data_.assign(data, data + data_size); - } - - void resize_data(const size_t data_size) { - data_.resize(data_size); - } + const DecryptConfig* decrypt_config() const { return decrypt_config_.get(); } void set_is_key_frame(bool value) { is_key_frame_ = value; @@ -158,32 +144,42 @@ class MediaSample { } // If there's no data in this buffer, it represents end of stream. - bool end_of_stream() const { return data_.size() == 0; } + bool end_of_stream() const { return data_size_ == 0; } const std::string& config_id() const { return config_id_; } void set_config_id(const std::string& config_id) { config_id_ = config_id; } - /// @return a human-readable string describing |*this|. - std::string ToString() const; + protected: + // Made it protected to disallow the constructor to be called directly. + // Create a MediaSample. Buffer will be padded and aligned as necessary. + // |data|,|side_data| can be nullptr, which indicates an empty sample. + MediaSample(const uint8_t* data, + size_t data_size, + const uint8_t* side_data, + size_t side_data_size, + bool is_key_frame); + MediaSample(); private: // Decoding time stamp. - int64_t dts_; + int64_t dts_ = 0; // Presentation time stamp. - int64_t pts_; - int64_t duration_; - bool is_key_frame_; + int64_t pts_ = 0; + int64_t duration_ = 0; + bool is_key_frame_ = false; // is sample encrypted ? - bool is_encrypted_; + bool is_encrypted_ = false; // Main buffer data. - std::vector data_; + std::shared_ptr data_; + size_t data_size_ = 0; // Contain additional buffers to complete the main one. Needed by WebM // http://www.matroska.org/technical/specs/index.html BlockAdditional[A5]. // Not used by mp4 and other containers. - std::vector side_data_; + std::shared_ptr side_data_; + size_t side_data_size_ = 0; // Text specific fields. // For now this is the cue identifier for WebVTT. diff --git a/packager/media/crypto/encryption_handler.cc b/packager/media/crypto/encryption_handler.cc index 095b37515f..7fb871ead9 100644 --- a/packager/media/crypto/encryption_handler.cc +++ b/packager/media/crypto/encryption_handler.cc @@ -9,6 +9,7 @@ #include #include +#include #include #include "packager/media/base/aes_encryptor.h" @@ -259,49 +260,47 @@ Status EncryptionHandler::ProcessMediaSample( // in-place. std::shared_ptr cipher_sample = MediaSample::CopyFrom(*clear_sample); + // |cipher_sample| above still contains the old clear sample data. We will + // use |cipher_sample_data| to hold cipher sample data then transfer it to + // |cipher_sample| after encryption. + std::shared_ptr cipher_sample_data( + new uint8_t[clear_sample->data_size()], std::default_delete()); - Status result; if (vpx_parser_) { - if (EncryptVpxFrame(vpx_frames, - cipher_sample->writable_data(), - cipher_sample->data_size(), - decrypt_config.get())) { - DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(), - cipher_sample->data_size()); - } else { - result = Status( - error::ENCRYPTION_FAILURE, - "Failed to encrypt VPX frame."); + if (!EncryptVpxFrame(vpx_frames, clear_sample->data(), + clear_sample->data_size(), + &cipher_sample_data.get()[0], decrypt_config.get())) { + return Status(error::ENCRYPTION_FAILURE, "Failed to encrypt VPX frame."); } + DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(), + clear_sample->data_size()); } else if (header_parser_) { - if (EncryptNalFrame(cipher_sample->writable_data(), - cipher_sample->data_size(), - decrypt_config.get())) { - DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(), - cipher_sample->data_size()); - } else { - result = Status( - error::ENCRYPTION_FAILURE, - "Failed to encrypt NAL frame."); + if (!EncryptNalFrame(clear_sample->data(), clear_sample->data_size(), + &cipher_sample_data.get()[0], decrypt_config.get())) { + return Status(error::ENCRYPTION_FAILURE, "Failed to encrypt NAL frame."); + } + DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(), + clear_sample->data_size()); + } else { + memcpy(&cipher_sample_data.get()[0], clear_sample->data(), + std::min(clear_sample->data_size(), leading_clear_bytes_size_)); + if (clear_sample->data_size() > leading_clear_bytes_size_) { + EncryptBytes(clear_sample->data() + leading_clear_bytes_size_, + clear_sample->data_size() - leading_clear_bytes_size_, + &cipher_sample_data.get()[leading_clear_bytes_size_]); } - } else if (cipher_sample->data_size() > leading_clear_bytes_size_) { - EncryptBytes( - cipher_sample->writable_data() + leading_clear_bytes_size_, - cipher_sample->data_size() - leading_clear_bytes_size_); } - if (!result.ok()) { - return result; - } - - encryptor_->UpdateIv(); - + cipher_sample->TransferData(std::move(cipher_sample_data), + clear_sample->data_size()); // Finish initializing the sample before sending it downstream. We must - // wait until now to finish the initialization as we will loose access to + // wait until now to finish the initialization as we will lose access to // |decrypt_config| once we set it. cipher_sample->set_is_encrypted(true); cipher_sample->set_decrypt_config(std::move(decrypt_config)); + encryptor_->UpdateIv(); + return DispatchMediaSample(kStreamIndex, std::move(cipher_sample)); } @@ -433,10 +432,11 @@ bool EncryptionHandler::CreateEncryptor(const EncryptionKey& encryption_key) { bool EncryptionHandler::EncryptVpxFrame( const std::vector& vpx_frames, - uint8_t* source, + const uint8_t* source, size_t source_size, + uint8_t* dest, DecryptConfig* decrypt_config) { - uint8_t* data = source; + const uint8_t* data = source; for (const VPxFrameInfo& frame : vpx_frames) { uint16_t clear_bytes = static_cast(frame.uncompressed_header_size); @@ -456,9 +456,11 @@ bool EncryptionHandler::EncryptVpxFrame( cipher_bytes -= misalign_bytes; decrypt_config->AddSubsample(clear_bytes, cipher_bytes); + memcpy(dest, data, clear_bytes); if (cipher_bytes > 0) - EncryptBytes(data + clear_bytes, cipher_bytes); + EncryptBytes(data + clear_bytes, cipher_bytes, dest + clear_bytes); data += frame.frame_size; + dest += frame.frame_size; } // Add subsample for the superframe index if exists. const bool is_superframe = vpx_frames.size() > 1; @@ -469,18 +471,20 @@ bool EncryptionHandler::EncryptVpxFrame( uint16_t clear_bytes = static_cast(index_size); uint32_t cipher_bytes = 0; decrypt_config->AddSubsample(clear_bytes, cipher_bytes); + memcpy(dest, data, clear_bytes); } return true; } -bool EncryptionHandler::EncryptNalFrame(uint8_t* data, - size_t data_length, +bool EncryptionHandler::EncryptNalFrame(const uint8_t* source, + size_t source_size, + uint8_t* dest, DecryptConfig* decrypt_config) { DCHECK_NE(nalu_length_size_, 0u); DCHECK(header_parser_); const Nalu::CodecType nalu_type = (codec_ == kCodecH265) ? Nalu::kH265 : Nalu::kH264; - NaluReader reader(nalu_type, nalu_length_size_, data, data_length); + NaluReader reader(nalu_type, nalu_length_size_, source, source_size); // Store the current length of clear data. This is used to squash // multiple unencrypted NAL units into fewer subsample entries. @@ -519,13 +523,17 @@ bool EncryptionHandler::EncryptNalFrame(uint8_t* data, cipher_bytes -= misalign_bytes; } - const uint8_t* nalu_data = nalu.data() + current_clear_bytes; - EncryptBytes(const_cast(nalu_data), cipher_bytes); - - AddSubsample( - accumulated_clear_bytes + nalu_length_size_ + current_clear_bytes, - cipher_bytes, decrypt_config); + accumulated_clear_bytes += nalu_length_size_ + current_clear_bytes; + AddSubsample(accumulated_clear_bytes, cipher_bytes, decrypt_config); + memcpy(dest, source, accumulated_clear_bytes); + source += accumulated_clear_bytes; + dest += accumulated_clear_bytes; accumulated_clear_bytes = 0; + + DCHECK_EQ(nalu.data() + current_clear_bytes, source); + EncryptBytes(source, cipher_bytes, dest); + source += cipher_bytes; + dest += cipher_bytes; } else { // For non-video-slice or small NAL units, don't encrypt. accumulated_clear_bytes += nalu_length_size_ + nalu_total_size; @@ -536,13 +544,17 @@ bool EncryptionHandler::EncryptNalFrame(uint8_t* data, return false; } AddSubsample(accumulated_clear_bytes, 0, decrypt_config); + memcpy(dest, source, accumulated_clear_bytes); return true; } -void EncryptionHandler::EncryptBytes(uint8_t* data, size_t size) { - DCHECK(data); +void EncryptionHandler::EncryptBytes(const uint8_t* source, + size_t source_size, + uint8_t* dest) { + DCHECK(source); + DCHECK(dest); DCHECK(encryptor_); - CHECK(encryptor_->Crypt(data, size, data)); + CHECK(encryptor_->Crypt(source, source_size, dest)); } void EncryptionHandler::InjectVpxParserForTesting( diff --git a/packager/media/crypto/encryption_handler.h b/packager/media/crypto/encryption_handler.h index cd93a89869..bbe0ee7a3f 100644 --- a/packager/media/crypto/encryption_handler.h +++ b/packager/media/crypto/encryption_handler.h @@ -47,15 +47,22 @@ class EncryptionHandler : public MediaHandler { Status SetupProtectionPattern(StreamType stream_type); bool CreateEncryptor(const EncryptionKey& encryption_key); + // Encrypt a VPx frame with size |source_size|. |dest| should have at least + // |source_size| bytes. bool EncryptVpxFrame(const std::vector& vpx_frames, - uint8_t* source, + const uint8_t* source, size_t source_size, + uint8_t* dest, DecryptConfig* decrypt_config); - bool EncryptNalFrame(uint8_t* data, - size_t data_length, + // Encrypt a NAL unit frame with size |source_size|. |dest| should have at + // least |source_size| bytes. + bool EncryptNalFrame(const uint8_t* source, + size_t source_size, + uint8_t* dest, DecryptConfig* decrypt_config); - void EncryptBytes(uint8_t* data, - size_t size); + // Encrypt an array with size |source_size|. |dest| should have at + // least |source_size| bytes. + void EncryptBytes(const uint8_t* source, size_t source_size, uint8_t* dest); // Testing injections. void InjectVpxParserForTesting(std::unique_ptr vpx_parser); diff --git a/packager/media/crypto/encryption_handler_unittest.cc b/packager/media/crypto/encryption_handler_unittest.cc index 5874c8f9e7..376a7a79f1 100644 --- a/packager/media/crypto/encryption_handler_unittest.cc +++ b/packager/media/crypto/encryption_handler_unittest.cc @@ -178,6 +178,10 @@ const uint8_t kData[]{ // Third non-video-slice NALU for H264 or superframe index for VP9. 0x06, 0x67, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, }; +const size_t kDataSize = sizeof(kData); +// A short data size (less than leading clear bytes) for SampleAes audio +// testing. +const size_t kShortDataSize = 14; // H264 subsample information for the the above data. const size_t kNaluLengthSize = 1u; @@ -277,7 +281,7 @@ class EncryptionHandlerEncryptionTest case kCodecVP9: if (vp9_subsample_encryption_) { std::unique_ptr mock_vpx_parser(new MockVpxParser); - EXPECT_CALL(*mock_vpx_parser, Parse(_, sizeof(kData), _)) + EXPECT_CALL(*mock_vpx_parser, Parse(_, kDataSize, _)) .WillRepeatedly( DoAll(SetArgPointee<2>(GetMockVpxFrameInfo()), Return(true))); InjectVpxParserForTesting(std::move(mock_vpx_parser)); @@ -496,13 +500,8 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithNoKeyRotation) { for (int i = 0; i < 3; ++i) { // Use single-frame segment for testing. ASSERT_OK(Process(StreamData::FromMediaSample( - kStreamIndex, - GetMediaSample( - i * kSegmentDuration, - kSegmentDuration, - kIsKeyFrame, - kData, - sizeof(kData))))); + kStreamIndex, GetMediaSample(i * kSegmentDuration, kSegmentDuration, + kIsKeyFrame, kData, kDataSize)))); ASSERT_OK(Process(StreamData::FromSegmentInfo( kStreamIndex, GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment)))); @@ -568,13 +567,8 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithKeyRotation) { } // Use single-frame segment for testing. ASSERT_OK(Process(StreamData::FromMediaSample( - kStreamIndex, - GetMediaSample( - i * kSegmentDuration, - kSegmentDuration, - kIsKeyFrame, - kData, - sizeof(kData))))); + kStreamIndex, GetMediaSample(i * kSegmentDuration, kSegmentDuration, + kIsKeyFrame, kData, kDataSize)))); ASSERT_OK(Process(StreamData::FromSegmentInfo( kStreamIndex, GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment)))); @@ -625,20 +619,9 @@ TEST_P(EncryptionHandlerEncryptionTest, Encrypt) { InjectCodecParser(); - std::unique_ptr stream_data(new StreamData); - stream_data->stream_index = 0; - stream_data->stream_data_type = StreamDataType::kMediaSample; - stream_data->media_sample.reset( - new MediaSample(kData, sizeof(kData), nullptr, 0, kIsKeyFrame)); - ASSERT_OK(Process(StreamData::FromMediaSample( kStreamIndex, - GetMediaSample( - 0, - kSampleDuration, - kIsKeyFrame, - kData, - sizeof(kData))))); + GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kDataSize)))); ASSERT_EQ(2u, GetOutputStreamDataVector().size()); ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index); ASSERT_EQ(StreamDataType::kMediaSample, @@ -654,16 +637,52 @@ TEST_P(EncryptionHandlerEncryptionTest, Encrypt) { EXPECT_EQ(GetExpectedCryptByteBlock(), decrypt_config->crypt_byte_block()); EXPECT_EQ(GetExpectedSkipByteBlock(), decrypt_config->skip_byte_block()); - std::vector expected( - kData, - kData + sizeof(kData)); - std::vector actual( - media_sample->data(), - media_sample->data() + media_sample->data_size()); + std::vector expected(kData, kData + kDataSize); + std::vector actual(media_sample->data(), + media_sample->data() + media_sample->data_size()); ASSERT_TRUE(Decrypt(*decrypt_config, actual.data(), actual.size())); EXPECT_EQ(expected, actual); } +// Verify that the data in short audio (less than leading clear bytes) is left +// unencrypted. +TEST_P(EncryptionHandlerEncryptionTest, SampleAesEncryptShortAudio) { + if (IsVideoCodec(codec_) || + protection_scheme_ != kAppleSampleAesProtectionScheme) { + return; + } + EncryptionParams encryption_params; + encryption_params.protection_scheme = kAppleSampleAesProtectionScheme; + SetUpEncryptionHandler(encryption_params); + + const EncryptionKey mock_encryption_key = GetMockEncryptionKey(); + EXPECT_CALL(mock_key_source_, GetKey(_, _)) + .WillOnce( + DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK))); + + ASSERT_OK(Process(StreamData::FromStreamInfo( + kStreamIndex, GetAudioStreamInfo(kTimeScale, codec_)))); + + ASSERT_OK(Process(StreamData::FromMediaSample( + kStreamIndex, + GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kShortDataSize)))); + ASSERT_EQ(2u, GetOutputStreamDataVector().size()); + ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index); + ASSERT_EQ(StreamDataType::kMediaSample, + GetOutputStreamDataVector().back()->stream_data_type); + + auto* media_sample = GetOutputStreamDataVector().back()->media_sample.get(); + auto* decrypt_config = media_sample->decrypt_config(); + EXPECT_TRUE(decrypt_config->subsamples().empty()); + EXPECT_EQ(kAppleSampleAesProtectionScheme, + decrypt_config->protection_scheme()); + + std::vector expected(kData, kData + kShortDataSize); + std::vector actual(media_sample->data(), + media_sample->data() + media_sample->data_size()); + EXPECT_EQ(expected, actual); +} + INSTANTIATE_TEST_CASE_P( CencProtectionSchemes, EncryptionHandlerEncryptionTest, diff --git a/packager/media/formats/mp4/mp4_media_parser.cc b/packager/media/formats/mp4/mp4_media_parser.cc index f51feaafc1..ad09ab5461 100644 --- a/packager/media/formats/mp4/mp4_media_parser.cc +++ b/packager/media/formats/mp4/mp4_media_parser.cc @@ -708,9 +708,17 @@ bool MP4MediaParser::EnqueueSample(bool* err) { return false; } + const uint8_t* media_data = buf; + const size_t media_data_size = runs_->sample_size(); + // Use a dummy data size of 0 to avoid copying overhead. + // Actual media data is set later. + const size_t kDummyDataSize = 0; std::shared_ptr stream_sample( - MediaSample::CopyFrom(buf, runs_->sample_size(), runs_->is_keyframe())); + MediaSample::CopyFrom(media_data, kDummyDataSize, runs_->is_keyframe())); + if (runs_->is_encrypted()) { + std::shared_ptr decrypted_media_data( + new uint8_t[media_data_size], std::default_delete()); std::unique_ptr decrypt_config = runs_->GetDecryptConfig(); if (!decrypt_config) { *err = true; @@ -719,17 +727,24 @@ bool MP4MediaParser::EnqueueSample(bool* err) { } if (!decryptor_source_) { + stream_sample->SetData(media_data, media_data_size); // If the demuxer does not have the decryptor_source_, store // decrypt_config so that the demuxed sample can be decrypted later. stream_sample->set_decrypt_config(std::move(decrypt_config)); stream_sample->set_is_encrypted(true); - } else if (!decryptor_source_->DecryptSampleBuffer( - decrypt_config.get(), stream_sample->writable_data(), - stream_sample->data_size())) { - *err = true; - LOG(ERROR) << "Cannot decrypt samples."; - return false; + } else { + if (!decryptor_source_->DecryptSampleBuffer(decrypt_config.get(), + media_data, media_data_size, + decrypted_media_data.get())) { + *err = true; + LOG(ERROR) << "Cannot decrypt samples."; + return false; + } + stream_sample->TransferData(std::move(decrypted_media_data), + media_data_size); } + } else { + stream_sample->SetData(media_data, media_data_size); } stream_sample->set_dts(runs_->dts()); diff --git a/packager/media/formats/webm/encryptor.cc b/packager/media/formats/webm/encryptor.cc index ae160d53b8..bb33f54210 100644 --- a/packager/media/formats/webm/encryptor.cc +++ b/packager/media/formats/webm/encryptor.cc @@ -13,6 +13,57 @@ namespace shaka { namespace media { namespace webm { +namespace { +void WriteEncryptedFrameHeader(const DecryptConfig* decrypt_config, + BufferWriter* header_buffer) { + if (decrypt_config) { + const size_t iv_size = decrypt_config->iv().size(); + DCHECK_EQ(iv_size, kWebMIvSize); + if (!decrypt_config->subsamples().empty()) { + const auto& subsamples = decrypt_config->subsamples(); + // Use partitioned subsample encryption: | signal_byte(3) | iv + // | num_partitions | partition_offset * n | enc_data | + DCHECK_LT(subsamples.size(), kWebMMaxSubsamples); + const size_t num_partitions = + 2 * subsamples.size() - 1 - + (subsamples.back().cipher_bytes == 0 ? 1 : 0); + const size_t header_size = kWebMSignalByteSize + iv_size + + kWebMNumPartitionsSize + + (kWebMPartitionOffsetSize * num_partitions); + + const uint8_t signal_byte = kWebMEncryptedSignal | kWebMPartitionedSignal; + header_buffer->AppendInt(signal_byte); + header_buffer->AppendVector(decrypt_config->iv()); + header_buffer->AppendInt(static_cast(num_partitions)); + + uint32_t partition_offset = 0; + for (size_t i = 0; i < subsamples.size() - 1; ++i) { + partition_offset += subsamples[i].clear_bytes; + header_buffer->AppendInt(partition_offset); + partition_offset += subsamples[i].cipher_bytes; + header_buffer->AppendInt(partition_offset); + } + // Add another partition between the clear bytes and cipher bytes if + // cipher bytes is not zero. + if (subsamples.back().cipher_bytes != 0) { + partition_offset += subsamples.back().clear_bytes; + header_buffer->AppendInt(partition_offset); + } + + DCHECK_EQ(header_size, header_buffer->Size()); + } else { + // Use whole-frame encryption: | signal_byte(1) | iv | enc_data | + const uint8_t signal_byte = kWebMEncryptedSignal; + header_buffer->AppendInt(signal_byte); + header_buffer->AppendVector(decrypt_config->iv()); + } + } else { + // Clear sample: | signal_byte(0) | data | + const uint8_t signal_byte = 0x00; + header_buffer->AppendInt(signal_byte); + } +} +} // namespace Status UpdateTrackForEncryption(const std::vector& key_id, mkvmuxer::Track* track) { @@ -46,69 +97,17 @@ Status UpdateTrackForEncryption(const std::vector& key_id, } void UpdateFrameForEncryption(MediaSample* sample) { - const size_t sample_size = sample->data_size(); - if (sample->decrypt_config()) { - auto* decrypt_config = sample->decrypt_config(); - const size_t iv_size = decrypt_config->iv().size(); - DCHECK_EQ(iv_size, kWebMIvSize); - if (!decrypt_config->subsamples().empty()) { - auto& subsamples = decrypt_config->subsamples(); - // Use partitioned subsample encryption: | signal_byte(3) | iv - // | num_partitions | partition_offset * n | enc_data | - DCHECK_LT(subsamples.size(), kWebMMaxSubsamples); - const size_t num_partitions = - 2 * subsamples.size() - 1 - - (subsamples.back().cipher_bytes == 0 ? 1 : 0); - const size_t header_size = kWebMSignalByteSize + iv_size + - kWebMNumPartitionsSize + - (kWebMPartitionOffsetSize * num_partitions); - sample->resize_data(header_size + sample_size); - uint8_t* sample_data = sample->writable_data(); - memmove(sample_data + header_size, sample_data, sample_size); - sample_data[0] = kWebMEncryptedSignal | kWebMPartitionedSignal; - memcpy(sample_data + kWebMSignalByteSize, decrypt_config->iv().data(), - iv_size); - sample_data[kWebMSignalByteSize + kWebMIvSize] = - static_cast(num_partitions); + DCHECK(sample); + BufferWriter header_buffer; + WriteEncryptedFrameHeader(sample->decrypt_config(), &header_buffer); - BufferWriter offsets_buffer; - uint32_t partition_offset = 0; - for (size_t i = 0; i < subsamples.size() - 1; ++i) { - partition_offset += subsamples[i].clear_bytes; - offsets_buffer.AppendInt(partition_offset); - partition_offset += subsamples[i].cipher_bytes; - offsets_buffer.AppendInt(partition_offset); - } - // Add another partition between the clear bytes and cipher bytes if - // cipher bytes is not zero. - if (subsamples.back().cipher_bytes != 0) { - partition_offset += subsamples.back().clear_bytes; - offsets_buffer.AppendInt(partition_offset); - } - DCHECK_EQ(num_partitions * kWebMPartitionOffsetSize, - offsets_buffer.Size()); - memcpy(sample_data + kWebMSignalByteSize + kWebMIvSize + - kWebMNumPartitionsSize, - offsets_buffer.Buffer(), offsets_buffer.Size()); - } else { - // Use whole-frame encryption: | signal_byte(1) | iv | enc_data | - sample->resize_data(sample_size + iv_size + kWebMSignalByteSize); - uint8_t* sample_data = sample->writable_data(); - - // First move the sample data to after the IV; then write the IV and - // signal byte. - memmove(sample_data + iv_size + kWebMSignalByteSize, sample_data, - sample_size); - sample_data[0] = kWebMEncryptedSignal; - memcpy(sample_data + 1, decrypt_config->iv().data(), iv_size); - } - } else { - // Clear sample: | signal_byte(0) | data | - sample->resize_data(sample_size + 1); - uint8_t* sample_data = sample->writable_data(); - memmove(sample_data + 1, sample_data, sample_size); - sample_data[0] = 0x00; - } + const size_t sample_size = header_buffer.Size() + sample->data_size(); + std::shared_ptr new_sample_data(new uint8_t[sample_size], + std::default_delete()); + memcpy(new_sample_data.get(), header_buffer.Buffer(), header_buffer.Size()); + memcpy(&new_sample_data.get()[header_buffer.Size()], sample->data(), + sample->data_size()); + sample->TransferData(std::move(new_sample_data), sample_size); } } // namespace webm diff --git a/packager/media/formats/webm/webm_cluster_parser.cc b/packager/media/formats/webm/webm_cluster_parser.cc index 023c6526da..350342a75f 100644 --- a/packager/media/formats/webm/webm_cluster_parser.cc +++ b/packager/media/formats/webm/webm_cluster_parser.cc @@ -376,21 +376,34 @@ bool WebMClusterParser::OnBlock(bool is_simple_block, return false; } - buffer = MediaSample::CopyFrom(data + data_offset, size - data_offset, - additional, additional_size, is_key_frame); + const uint8_t* media_data = data + data_offset; + const size_t media_data_size = size - data_offset; + // Use a dummy data size of 0 to avoid copying overhead. + // Actual media data is set later. + const size_t kDummyDataSize = 0; + buffer = MediaSample::CopyFrom(media_data, kDummyDataSize, additional, + additional_size, is_key_frame); if (decrypt_config) { if (!decryptor_source_) { + buffer->SetData(media_data, media_data_size); // If the demuxer does not have the decryptor_source_, store // decrypt_config so that the demuxed sample can be decrypted later. buffer->set_decrypt_config(std::move(decrypt_config)); buffer->set_is_encrypted(true); - } else if (!decryptor_source_->DecryptSampleBuffer( - decrypt_config.get(), buffer->writable_data(), - buffer->data_size())) { - LOG(ERROR) << "Cannot decrypt samples"; - return false; + } else { + std::shared_ptr decrypted_media_data( + new uint8_t[media_data_size], std::default_delete()); + if (!decryptor_source_->DecryptSampleBuffer( + decrypt_config.get(), media_data, media_data_size, + decrypted_media_data.get())) { + LOG(ERROR) << "Cannot decrypt samples"; + return false; + } + buffer->TransferData(std::move(decrypted_media_data), media_data_size); } + } else { + buffer->SetData(media_data, media_data_size); } } else { std::string id, settings, content; diff --git a/packager/media/formats/wvm/wvm_media_parser.cc b/packager/media/formats/wvm/wvm_media_parser.cc index eee6fcd5f7..089d7dd515 100644 --- a/packager/media/formats/wvm/wvm_media_parser.cc +++ b/packager/media/formats/wvm/wvm_media_parser.cc @@ -816,7 +816,7 @@ void WvmMediaParser::StartMediaSampleDemux() { bool WvmMediaParser::Output(bool output_encrypted_sample) { if (output_encrypted_sample) { - media_sample_->set_data(sample_data_.data(), sample_data_.size()); + media_sample_->SetData(sample_data_.data(), sample_data_.size()); media_sample_->set_is_encrypted(true); } else { if ((prev_pes_stream_id_ & kPesStreamIdVideoMask) == kPesStreamIdVideo) { @@ -827,7 +827,7 @@ bool WvmMediaParser::Output(bool output_encrypted_sample) { LOG(ERROR) << "Could not convert h.264 byte stream sample"; return false; } - media_sample_->set_data(nal_unit_stream.data(), nal_unit_stream.size()); + media_sample_->SetData(nal_unit_stream.data(), nal_unit_stream.size()); if (!is_initialized_) { // Set extra data for video stream from AVC Decoder Config Record. // Also, set codec string from the AVC Decoder Config Record. @@ -914,8 +914,7 @@ bool WvmMediaParser::Output(bool output_encrypted_sample) { } size_t header_size = adts_header.GetAdtsHeaderSize(frame_ptr, frame_size); - media_sample_->set_data(frame_ptr + header_size, - frame_size - header_size); + media_sample_->SetData(frame_ptr + header_size, frame_size - header_size); if (!is_initialized_) { for (uint32_t i = 0; i < stream_infos_.size(); i++) { if (stream_infos_[i]->stream_type() == kStreamAudio &&