Update MediaSample to avoid redundent copying

Use std::shared_ptr<const uint8_t> with a custom deleter to
represent MediaSample::data_ instead of std::vector<uint8_t>.

MediaSample::data_ can be shared by multiple MediaSamples and it is
immutable. A new data instance must be created if the clients want to
modify the underlying data. The new data instance can be transferred
to MediaSample using provided MediaSample::TransferData function.
This avoids unnecessary data copying.

Change-Id: Ib59785a9e19d0abb3283179b12eb6779ee922f79
This commit is contained in:
KongQun Yang 2017-09-18 16:31:00 -07:00
parent 6c0f2bebef
commit 92e1e39868
14 changed files with 402 additions and 282 deletions

View File

@ -10,6 +10,18 @@
#include "packager/media/base/aes_decryptor.h"
#include "packager/media/base/aes_pattern_cryptor.h"
namespace {
// Return true if [encrypted_buffer, encrypted_buffer + buffer_size) overlaps
// with [decrypted_buffer, decrypted_buffer + buffer_size).
bool CheckMemoryOverlap(const uint8_t* encrypted_buffer,
size_t buffer_size,
uint8_t* decrypted_buffer) {
return (decrypted_buffer < encrypted_buffer)
? (encrypted_buffer < decrypted_buffer + buffer_size)
: (decrypted_buffer < encrypted_buffer + buffer_size);
}
} // namespace
namespace shaka {
namespace media {
@ -20,10 +32,17 @@ DecryptorSource::DecryptorSource(KeySource* key_source)
DecryptorSource::~DecryptorSource() {}
bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config,
uint8_t* buffer,
size_t buffer_size) {
const uint8_t* encrypted_buffer,
size_t buffer_size,
uint8_t* decrypted_buffer) {
DCHECK(decrypt_config);
DCHECK(buffer);
DCHECK(encrypted_buffer);
DCHECK(decrypted_buffer);
if (CheckMemoryOverlap(encrypted_buffer, buffer_size, decrypted_buffer)) {
LOG(ERROR) << "Encrypted buffer and decrypted buffer cannot overlap.";
return false;
}
// Get the decryptor object.
AesCryptor* decryptor = nullptr;
@ -83,7 +102,7 @@ bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config,
if (decrypt_config->subsamples().empty()) {
// Sample not encrypted using subsample encryption. Decrypt whole.
if (!decryptor->Crypt(buffer, buffer_size, buffer)) {
if (!decryptor->Crypt(encrypted_buffer, buffer_size, decrypted_buffer)) {
LOG(ERROR) << "Error during bulk sample decryption.";
return false;
}
@ -92,20 +111,24 @@ bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config,
// Subsample decryption.
const std::vector<SubsampleEntry>& subsamples = decrypt_config->subsamples();
uint8_t* current_ptr = buffer;
const uint8_t* const buffer_end = buffer + buffer_size;
const uint8_t* current_ptr = encrypted_buffer;
const uint8_t* const buffer_end = encrypted_buffer + buffer_size;
for (const auto& subsample : subsamples) {
if ((current_ptr + subsample.clear_bytes + subsample.cipher_bytes) >
buffer_end) {
LOG(ERROR) << "Subsamples overflow sample buffer.";
return false;
}
memcpy(decrypted_buffer, current_ptr, subsample.clear_bytes);
current_ptr += subsample.clear_bytes;
if (!decryptor->Crypt(current_ptr, subsample.cipher_bytes, current_ptr)) {
decrypted_buffer += subsample.clear_bytes;
if (!decryptor->Crypt(current_ptr, subsample.cipher_bytes,
decrypted_buffer)) {
LOG(ERROR) << "Error decrypting subsample buffer.";
return false;
}
current_ptr += subsample.cipher_bytes;
decrypted_buffer += subsample.cipher_bytes;
}
return true;
}

View File

@ -21,12 +21,24 @@ namespace media {
/// DecryptorSource wraps KeySource and is responsible for decryptor management.
class DecryptorSource {
public:
/// Constructs a DecryptorSource object.
/// @param key_source points to the key source that contains the keys.
explicit DecryptorSource(KeySource* key_source);
~DecryptorSource();
/// Decrypt encrypted buffer.
/// @param decrypt_config contains decrypt configuration, e.g. protection
/// scheme, subsample information etc.
/// @param encrypted_buffer points to the encrypted buffer that is to be
/// decrypted. It should not overlap with @a decrypted_buffer.
/// @param buffer_size is the size of encrypted buffer and decrypted buffer.
/// @param decrypted_buffer points to the decrypted buffer. It should not
/// overlap with @a encrypted_buffer.
/// @return true if success, false otherwise.
bool DecryptSampleBuffer(const DecryptConfig* decrypt_config,
uint8_t* buffer,
size_t buffer_size);
const uint8_t* encrypted_buffer,
size_t buffer_size,
uint8_t* decrypted_buffer);
private:
KeySource* key_source_;

View File

@ -64,13 +64,15 @@ class DecryptorSourceTest : public ::testing::Test {
DecryptorSourceTest()
: decryptor_source_(&mock_key_source_),
key_id_(std::vector<uint8_t>(kKeyId, kKeyId + arraysize(kKeyId))),
buffer_(std::vector<uint8_t>(kBuffer, kBuffer + arraysize(kBuffer))) {}
encrypted_buffer_(kBuffer, kBuffer + arraysize(kBuffer)),
decrypted_buffer_(arraysize(kBuffer)) {}
protected:
StrictMock<MockKeySource> mock_key_source_;
DecryptorSource decryptor_source_;
std::vector<uint8_t> key_id_;
std::vector<uint8_t> buffer_;
std::vector<uint8_t> encrypted_buffer_;
std::vector<uint8_t> decrypted_buffer_;
};
TEST_F(DecryptorSourceTest, FullSampleDecryption) {
@ -83,24 +85,27 @@ TEST_F(DecryptorSourceTest, FullSampleDecryption) {
std::vector<uint8_t>(kIv, kIv + arraysize(kIv)),
std::vector<SubsampleEntry>());
ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size()));
&decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
EXPECT_EQ(std::vector<uint8_t>(
kExpectedDecryptedBuffer,
kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)),
buffer_);
decrypted_buffer_);
// DecryptSampleBuffer can be called repetitively. No GetKey call again with
// the same key id.
buffer_.assign(kBuffer2, kBuffer2 + arraysize(kBuffer2));
encrypted_buffer_.assign(kBuffer2, kBuffer2 + arraysize(kBuffer2));
decrypted_buffer_.resize(arraysize(kBuffer2));
DecryptConfig decrypt_config2(
key_id_, std::vector<uint8_t>(kIv2, kIv2 + arraysize(kIv2)),
std::vector<SubsampleEntry>());
ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config2, &buffer_[0], buffer_.size()));
&decrypt_config2, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
EXPECT_EQ(std::vector<uint8_t>(kExpectedDecryptedBuffer2,
kExpectedDecryptedBuffer2 +
arraysize(kExpectedDecryptedBuffer2)),
buffer_);
decrypted_buffer_);
}
TEST_F(DecryptorSourceTest, SubsampleDecryption) {
@ -131,11 +136,12 @@ TEST_F(DecryptorSourceTest, SubsampleDecryption) {
std::vector<SubsampleEntry>(kSubsamples,
kSubsamples + arraysize(kSubsamples)));
ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size()));
&decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
EXPECT_EQ(std::vector<uint8_t>(
kExpectedDecryptedBuffer,
kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)),
buffer_);
decrypted_buffer_);
}
TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) {
@ -155,7 +161,8 @@ TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) {
std::vector<SubsampleEntry>(kSubsamples,
kSubsamples + arraysize(kSubsamples)));
ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size()));
&decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
}
TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) {
@ -166,7 +173,17 @@ TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) {
std::vector<uint8_t>(kIv, kIv + arraysize(kIv)),
std::vector<SubsampleEntry>());
ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size()));
&decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
}
TEST_F(DecryptorSourceTest, EncryptedBufferAndDecryptedBufferOverlap) {
DecryptConfig decrypt_config(key_id_,
std::vector<uint8_t>(kIv, kIv + arraysize(kIv)),
std::vector<SubsampleEntry>());
ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&encrypted_buffer_[5]));
}
} // namespace media

View File

@ -154,21 +154,21 @@ std::unique_ptr<StreamInfo> MediaHandlerTestBase::GetAudioStreamInfo(
!kEncrypted));
}
std::unique_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample(
std::shared_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample(
int64_t timestamp,
int64_t duration,
bool is_keyframe) const {
return GetMediaSample(timestamp, duration, is_keyframe, kData, sizeof(kData));
}
std::unique_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample(
std::shared_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample(
int64_t timestamp,
int64_t duration,
bool is_keyframe,
const uint8_t* data,
size_t data_length) const {
std::unique_ptr<MediaSample> sample(
new MediaSample(data, data_length, nullptr, 0, is_keyframe));
std::shared_ptr<MediaSample> sample =
MediaSample::CopyFrom(data, data_length, nullptr, 0, is_keyframe);
sample->set_dts(timestamp);
sample->set_duration(duration);

View File

@ -150,11 +150,11 @@ class MediaHandlerTestBase : public ::testing::Test {
std::unique_ptr<StreamInfo> GetAudioStreamInfo(uint32_t time_scale,
Codec codec) const;
std::unique_ptr<MediaSample> GetMediaSample(int64_t timestamp,
std::shared_ptr<MediaSample> GetMediaSample(int64_t timestamp,
int64_t duration,
bool is_keyframe) const;
std::unique_ptr<MediaSample> GetMediaSample(int64_t timestamp,
std::shared_ptr<MediaSample> GetMediaSample(int64_t timestamp,
int64_t duration,
bool is_keyframe,
const uint8_t* data,

View File

@ -15,29 +15,26 @@ namespace shaka {
namespace media {
MediaSample::MediaSample(const uint8_t* data,
size_t size,
size_t data_size,
const uint8_t* side_data,
size_t side_data_size,
bool is_key_frame)
: dts_(0),
pts_(0),
duration_(0),
is_key_frame_(is_key_frame),
is_encrypted_(false) {
: is_key_frame_(is_key_frame) {
if (!data) {
CHECK_EQ(size, 0u);
CHECK_EQ(data_size, 0u);
}
data_.assign(data, data + size);
if (side_data)
side_data_.assign(side_data, side_data + side_data_size);
SetData(data, data_size);
if (side_data) {
std::shared_ptr<uint8_t> shared_side_data(new uint8_t[side_data_size],
std::default_delete<uint8_t[]>());
memcpy(shared_side_data.get(), side_data, side_data_size);
side_data_ = std::move(shared_side_data);
side_data_size_ = side_data_size;
}
}
MediaSample::MediaSample() : dts_(0),
pts_(0),
duration_(0),
is_key_frame_(false),
is_encrypted_(false) {}
MediaSample::MediaSample() {}
MediaSample::~MediaSample() {}
@ -47,8 +44,8 @@ std::shared_ptr<MediaSample> MediaSample::CopyFrom(const uint8_t* data,
bool is_key_frame) {
// If you hit this CHECK you likely have a bug in a demuxer. Go fix it.
CHECK(data);
return std::make_shared<MediaSample>(data, data_size, nullptr, 0u,
is_key_frame);
return std::shared_ptr<MediaSample>(
new MediaSample(data, data_size, nullptr, 0u, is_key_frame));
}
// static
@ -59,32 +56,32 @@ std::shared_ptr<MediaSample> MediaSample::CopyFrom(const uint8_t* data,
bool is_key_frame) {
// If you hit this CHECK you likely have a bug in a demuxer. Go fix it.
CHECK(data);
return std::make_shared<MediaSample>(data, data_size, side_data,
side_data_size, is_key_frame);
return std::shared_ptr<MediaSample>(new MediaSample(
data, data_size, side_data, side_data_size, is_key_frame));
}
// static
std::shared_ptr<MediaSample> MediaSample::CopyFrom(
const MediaSample& media_sample) {
std::shared_ptr<MediaSample> new_media_sample = CopyFrom(
media_sample.data(), media_sample.data_size(), media_sample.side_data(),
media_sample.side_data_size(), media_sample.is_key_frame());
new_media_sample->set_dts(media_sample.dts());
new_media_sample->set_pts(media_sample.pts());
new_media_sample->set_is_encrypted(media_sample.is_encrypted());
new_media_sample->set_config_id(media_sample.config_id());
new_media_sample->set_duration(media_sample.duration());
if (media_sample.decrypt_config()) {
std::unique_ptr<DecryptConfig> decrypt_config(
new DecryptConfig(media_sample.decrypt_config()->key_id(),
media_sample.decrypt_config()->iv(),
media_sample.decrypt_config()->subsamples(),
media_sample.decrypt_config()->protection_scheme(),
media_sample.decrypt_config()->crypt_byte_block(),
media_sample.decrypt_config()->skip_byte_block()));
new_media_sample->set_decrypt_config(std::move(decrypt_config));
std::shared_ptr<MediaSample> new_media_sample(new MediaSample);
new_media_sample->dts_ = media_sample.dts_;
new_media_sample->pts_ = media_sample.pts_;
new_media_sample->duration_ = media_sample.duration_;
new_media_sample->is_key_frame_ = media_sample.is_key_frame_;
new_media_sample->is_encrypted_ = media_sample.is_encrypted_;
new_media_sample->data_ = media_sample.data_;
new_media_sample->data_size_ = media_sample.data_size_;
new_media_sample->side_data_ = media_sample.side_data_;
new_media_sample->side_data_size_ = media_sample.side_data_size_;
new_media_sample->config_id_ = media_sample.config_id_;
if (media_sample.decrypt_config_) {
new_media_sample->decrypt_config_.reset(
new DecryptConfig(media_sample.decrypt_config_->key_id(),
media_sample.decrypt_config_->iv(),
media_sample.decrypt_config_->subsamples(),
media_sample.decrypt_config_->protection_scheme(),
media_sample.decrypt_config_->crypt_byte_block(),
media_sample.decrypt_config_->skip_byte_block()));
}
return new_media_sample;
}
@ -92,32 +89,43 @@ std::shared_ptr<MediaSample> MediaSample::CopyFrom(
// static
std::shared_ptr<MediaSample> MediaSample::FromMetadata(const uint8_t* metadata,
size_t metadata_size) {
return std::make_shared<MediaSample>(nullptr, 0, metadata, metadata_size,
false);
return std::shared_ptr<MediaSample>(
new MediaSample(nullptr, 0, metadata, metadata_size, false));
}
// static
std::shared_ptr<MediaSample> MediaSample::CreateEmptyMediaSample() {
return std::make_shared<MediaSample>();
return std::shared_ptr<MediaSample>(new MediaSample);
}
// static
std::shared_ptr<MediaSample> MediaSample::CreateEOSBuffer() {
return std::make_shared<MediaSample>(nullptr, 0, nullptr, 0, false);
return std::shared_ptr<MediaSample>(
new MediaSample(nullptr, 0, nullptr, 0, false));
}
void MediaSample::TransferData(std::shared_ptr<uint8_t> data,
size_t data_size) {
data_ = std::move(data);
data_size_ = data_size;
}
void MediaSample::SetData(const uint8_t* data, size_t data_size) {
std::shared_ptr<uint8_t> shared_data(new uint8_t[data_size],
std::default_delete<uint8_t[]>());
memcpy(shared_data.get(), data, data_size);
TransferData(std::move(shared_data), data_size);
}
std::string MediaSample::ToString() const {
if (end_of_stream())
return "End of stream sample\n";
return base::StringPrintf(
"dts: %" PRId64 "\n pts: %" PRId64 "\n duration: %" PRId64 "\n "
"dts: %" PRId64 "\n pts: %" PRId64 "\n duration: %" PRId64
"\n "
"is_key_frame: %s\n size: %zu\n side_data_size: %zu\n",
dts_,
pts_,
duration_,
is_key_frame_ ? "true" : "false",
data_.size(),
side_data_.size());
dts_, pts_, duration_, is_key_frame_ ? "true" : "false", data_size_,
side_data_size_);
}
} // namespace media

View File

@ -66,17 +66,22 @@ class MediaSample {
/// is disallowed.
static std::shared_ptr<MediaSample> CreateEOSBuffer();
// Create a MediaSample. Buffer will be padded and aligned as necessary.
// |data|,|side_data| can be NULL, which indicates an empty sample.
// |size|,|side_data_size| should not be negative.
MediaSample(const uint8_t* data,
size_t size,
const uint8_t* side_data,
size_t side_data_size,
bool is_key_frame);
MediaSample();
virtual ~MediaSample();
/// Transfer data to this media sample. No data copying is involved.
/// @param data points to the data to be transferred.
/// @param data_size is the size of the data to be transferred.
void TransferData(std::shared_ptr<uint8_t> data, size_t data_size);
/// Set the data in this media sample. Note that this method involves data
/// copying.
/// @param data points to the data to be copied.
/// @param data_size is the size of the data to be copied.
void SetData(const uint8_t* data, size_t data_size);
/// @return a human-readable string describing |*this|.
std::string ToString() const;
int64_t dts() const {
DCHECK(!end_of_stream());
return dts_;
@ -112,38 +117,19 @@ class MediaSample {
}
const uint8_t* data() const {
DCHECK(!end_of_stream());
return data_.data();
}
uint8_t* writable_data() {
DCHECK(!end_of_stream());
return data_.data();
return data_.get();
}
size_t data_size() const {
DCHECK(!end_of_stream());
return data_.size();
return data_size_;
}
const uint8_t* side_data() const {
return side_data_.data();
}
const uint8_t* side_data() const { return side_data_.get(); }
size_t side_data_size() const {
return side_data_.size();
}
size_t side_data_size() const { return side_data_size_; }
const DecryptConfig* decrypt_config() const {
return decrypt_config_.get();
}
void set_data(const uint8_t* data, const size_t data_size) {
data_.assign(data, data + data_size);
}
void resize_data(const size_t data_size) {
data_.resize(data_size);
}
const DecryptConfig* decrypt_config() const { return decrypt_config_.get(); }
void set_is_key_frame(bool value) {
is_key_frame_ = value;
@ -158,32 +144,42 @@ class MediaSample {
}
// If there's no data in this buffer, it represents end of stream.
bool end_of_stream() const { return data_.size() == 0; }
bool end_of_stream() const { return data_size_ == 0; }
const std::string& config_id() const { return config_id_; }
void set_config_id(const std::string& config_id) {
config_id_ = config_id;
}
/// @return a human-readable string describing |*this|.
std::string ToString() const;
protected:
// Made it protected to disallow the constructor to be called directly.
// Create a MediaSample. Buffer will be padded and aligned as necessary.
// |data|,|side_data| can be nullptr, which indicates an empty sample.
MediaSample(const uint8_t* data,
size_t data_size,
const uint8_t* side_data,
size_t side_data_size,
bool is_key_frame);
MediaSample();
private:
// Decoding time stamp.
int64_t dts_;
int64_t dts_ = 0;
// Presentation time stamp.
int64_t pts_;
int64_t duration_;
bool is_key_frame_;
int64_t pts_ = 0;
int64_t duration_ = 0;
bool is_key_frame_ = false;
// is sample encrypted ?
bool is_encrypted_;
bool is_encrypted_ = false;
// Main buffer data.
std::vector<uint8_t> data_;
std::shared_ptr<const uint8_t> data_;
size_t data_size_ = 0;
// Contain additional buffers to complete the main one. Needed by WebM
// http://www.matroska.org/technical/specs/index.html BlockAdditional[A5].
// Not used by mp4 and other containers.
std::vector<uint8_t> side_data_;
std::shared_ptr<const uint8_t> side_data_;
size_t side_data_size_ = 0;
// Text specific fields.
// For now this is the cue identifier for WebVTT.

View File

@ -9,6 +9,7 @@
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <limits>
#include "packager/media/base/aes_encryptor.h"
@ -259,49 +260,47 @@ Status EncryptionHandler::ProcessMediaSample(
// in-place.
std::shared_ptr<MediaSample> cipher_sample =
MediaSample::CopyFrom(*clear_sample);
// |cipher_sample| above still contains the old clear sample data. We will
// use |cipher_sample_data| to hold cipher sample data then transfer it to
// |cipher_sample| after encryption.
std::shared_ptr<uint8_t> cipher_sample_data(
new uint8_t[clear_sample->data_size()], std::default_delete<uint8_t[]>());
Status result;
if (vpx_parser_) {
if (EncryptVpxFrame(vpx_frames,
cipher_sample->writable_data(),
cipher_sample->data_size(),
decrypt_config.get())) {
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
cipher_sample->data_size());
} else {
result = Status(
error::ENCRYPTION_FAILURE,
"Failed to encrypt VPX frame.");
if (!EncryptVpxFrame(vpx_frames, clear_sample->data(),
clear_sample->data_size(),
&cipher_sample_data.get()[0], decrypt_config.get())) {
return Status(error::ENCRYPTION_FAILURE, "Failed to encrypt VPX frame.");
}
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
clear_sample->data_size());
} else if (header_parser_) {
if (EncryptNalFrame(cipher_sample->writable_data(),
cipher_sample->data_size(),
decrypt_config.get())) {
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
cipher_sample->data_size());
} else {
result = Status(
error::ENCRYPTION_FAILURE,
"Failed to encrypt NAL frame.");
if (!EncryptNalFrame(clear_sample->data(), clear_sample->data_size(),
&cipher_sample_data.get()[0], decrypt_config.get())) {
return Status(error::ENCRYPTION_FAILURE, "Failed to encrypt NAL frame.");
}
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
clear_sample->data_size());
} else {
memcpy(&cipher_sample_data.get()[0], clear_sample->data(),
std::min(clear_sample->data_size(), leading_clear_bytes_size_));
if (clear_sample->data_size() > leading_clear_bytes_size_) {
EncryptBytes(clear_sample->data() + leading_clear_bytes_size_,
clear_sample->data_size() - leading_clear_bytes_size_,
&cipher_sample_data.get()[leading_clear_bytes_size_]);
}
} else if (cipher_sample->data_size() > leading_clear_bytes_size_) {
EncryptBytes(
cipher_sample->writable_data() + leading_clear_bytes_size_,
cipher_sample->data_size() - leading_clear_bytes_size_);
}
if (!result.ok()) {
return result;
}
encryptor_->UpdateIv();
cipher_sample->TransferData(std::move(cipher_sample_data),
clear_sample->data_size());
// Finish initializing the sample before sending it downstream. We must
// wait until now to finish the initialization as we will loose access to
// wait until now to finish the initialization as we will lose access to
// |decrypt_config| once we set it.
cipher_sample->set_is_encrypted(true);
cipher_sample->set_decrypt_config(std::move(decrypt_config));
encryptor_->UpdateIv();
return DispatchMediaSample(kStreamIndex, std::move(cipher_sample));
}
@ -433,10 +432,11 @@ bool EncryptionHandler::CreateEncryptor(const EncryptionKey& encryption_key) {
bool EncryptionHandler::EncryptVpxFrame(
const std::vector<VPxFrameInfo>& vpx_frames,
uint8_t* source,
const uint8_t* source,
size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config) {
uint8_t* data = source;
const uint8_t* data = source;
for (const VPxFrameInfo& frame : vpx_frames) {
uint16_t clear_bytes =
static_cast<uint16_t>(frame.uncompressed_header_size);
@ -456,9 +456,11 @@ bool EncryptionHandler::EncryptVpxFrame(
cipher_bytes -= misalign_bytes;
decrypt_config->AddSubsample(clear_bytes, cipher_bytes);
memcpy(dest, data, clear_bytes);
if (cipher_bytes > 0)
EncryptBytes(data + clear_bytes, cipher_bytes);
EncryptBytes(data + clear_bytes, cipher_bytes, dest + clear_bytes);
data += frame.frame_size;
dest += frame.frame_size;
}
// Add subsample for the superframe index if exists.
const bool is_superframe = vpx_frames.size() > 1;
@ -469,18 +471,20 @@ bool EncryptionHandler::EncryptVpxFrame(
uint16_t clear_bytes = static_cast<uint16_t>(index_size);
uint32_t cipher_bytes = 0;
decrypt_config->AddSubsample(clear_bytes, cipher_bytes);
memcpy(dest, data, clear_bytes);
}
return true;
}
bool EncryptionHandler::EncryptNalFrame(uint8_t* data,
size_t data_length,
bool EncryptionHandler::EncryptNalFrame(const uint8_t* source,
size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config) {
DCHECK_NE(nalu_length_size_, 0u);
DCHECK(header_parser_);
const Nalu::CodecType nalu_type =
(codec_ == kCodecH265) ? Nalu::kH265 : Nalu::kH264;
NaluReader reader(nalu_type, nalu_length_size_, data, data_length);
NaluReader reader(nalu_type, nalu_length_size_, source, source_size);
// Store the current length of clear data. This is used to squash
// multiple unencrypted NAL units into fewer subsample entries.
@ -519,13 +523,17 @@ bool EncryptionHandler::EncryptNalFrame(uint8_t* data,
cipher_bytes -= misalign_bytes;
}
const uint8_t* nalu_data = nalu.data() + current_clear_bytes;
EncryptBytes(const_cast<uint8_t*>(nalu_data), cipher_bytes);
AddSubsample(
accumulated_clear_bytes + nalu_length_size_ + current_clear_bytes,
cipher_bytes, decrypt_config);
accumulated_clear_bytes += nalu_length_size_ + current_clear_bytes;
AddSubsample(accumulated_clear_bytes, cipher_bytes, decrypt_config);
memcpy(dest, source, accumulated_clear_bytes);
source += accumulated_clear_bytes;
dest += accumulated_clear_bytes;
accumulated_clear_bytes = 0;
DCHECK_EQ(nalu.data() + current_clear_bytes, source);
EncryptBytes(source, cipher_bytes, dest);
source += cipher_bytes;
dest += cipher_bytes;
} else {
// For non-video-slice or small NAL units, don't encrypt.
accumulated_clear_bytes += nalu_length_size_ + nalu_total_size;
@ -536,13 +544,17 @@ bool EncryptionHandler::EncryptNalFrame(uint8_t* data,
return false;
}
AddSubsample(accumulated_clear_bytes, 0, decrypt_config);
memcpy(dest, source, accumulated_clear_bytes);
return true;
}
void EncryptionHandler::EncryptBytes(uint8_t* data, size_t size) {
DCHECK(data);
void EncryptionHandler::EncryptBytes(const uint8_t* source,
size_t source_size,
uint8_t* dest) {
DCHECK(source);
DCHECK(dest);
DCHECK(encryptor_);
CHECK(encryptor_->Crypt(data, size, data));
CHECK(encryptor_->Crypt(source, source_size, dest));
}
void EncryptionHandler::InjectVpxParserForTesting(

View File

@ -47,15 +47,22 @@ class EncryptionHandler : public MediaHandler {
Status SetupProtectionPattern(StreamType stream_type);
bool CreateEncryptor(const EncryptionKey& encryption_key);
// Encrypt a VPx frame with size |source_size|. |dest| should have at least
// |source_size| bytes.
bool EncryptVpxFrame(const std::vector<VPxFrameInfo>& vpx_frames,
uint8_t* source,
const uint8_t* source,
size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config);
bool EncryptNalFrame(uint8_t* data,
size_t data_length,
// Encrypt a NAL unit frame with size |source_size|. |dest| should have at
// least |source_size| bytes.
bool EncryptNalFrame(const uint8_t* source,
size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config);
void EncryptBytes(uint8_t* data,
size_t size);
// Encrypt an array with size |source_size|. |dest| should have at
// least |source_size| bytes.
void EncryptBytes(const uint8_t* source, size_t source_size, uint8_t* dest);
// Testing injections.
void InjectVpxParserForTesting(std::unique_ptr<VPxParser> vpx_parser);

View File

@ -178,6 +178,10 @@ const uint8_t kData[]{
// Third non-video-slice NALU for H264 or superframe index for VP9.
0x06, 0x67, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
};
const size_t kDataSize = sizeof(kData);
// A short data size (less than leading clear bytes) for SampleAes audio
// testing.
const size_t kShortDataSize = 14;
// H264 subsample information for the the above data.
const size_t kNaluLengthSize = 1u;
@ -277,7 +281,7 @@ class EncryptionHandlerEncryptionTest
case kCodecVP9:
if (vp9_subsample_encryption_) {
std::unique_ptr<MockVpxParser> mock_vpx_parser(new MockVpxParser);
EXPECT_CALL(*mock_vpx_parser, Parse(_, sizeof(kData), _))
EXPECT_CALL(*mock_vpx_parser, Parse(_, kDataSize, _))
.WillRepeatedly(
DoAll(SetArgPointee<2>(GetMockVpxFrameInfo()), Return(true)));
InjectVpxParserForTesting(std::move(mock_vpx_parser));
@ -496,13 +500,8 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithNoKeyRotation) {
for (int i = 0; i < 3; ++i) {
// Use single-frame segment for testing.
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(
i * kSegmentDuration,
kSegmentDuration,
kIsKeyFrame,
kData,
sizeof(kData)))));
kStreamIndex, GetMediaSample(i * kSegmentDuration, kSegmentDuration,
kIsKeyFrame, kData, kDataSize))));
ASSERT_OK(Process(StreamData::FromSegmentInfo(
kStreamIndex,
GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment))));
@ -568,13 +567,8 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithKeyRotation) {
}
// Use single-frame segment for testing.
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(
i * kSegmentDuration,
kSegmentDuration,
kIsKeyFrame,
kData,
sizeof(kData)))));
kStreamIndex, GetMediaSample(i * kSegmentDuration, kSegmentDuration,
kIsKeyFrame, kData, kDataSize))));
ASSERT_OK(Process(StreamData::FromSegmentInfo(
kStreamIndex,
GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment))));
@ -625,20 +619,9 @@ TEST_P(EncryptionHandlerEncryptionTest, Encrypt) {
InjectCodecParser();
std::unique_ptr<StreamData> stream_data(new StreamData);
stream_data->stream_index = 0;
stream_data->stream_data_type = StreamDataType::kMediaSample;
stream_data->media_sample.reset(
new MediaSample(kData, sizeof(kData), nullptr, 0, kIsKeyFrame));
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(
0,
kSampleDuration,
kIsKeyFrame,
kData,
sizeof(kData)))));
GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kDataSize))));
ASSERT_EQ(2u, GetOutputStreamDataVector().size());
ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index);
ASSERT_EQ(StreamDataType::kMediaSample,
@ -654,16 +637,52 @@ TEST_P(EncryptionHandlerEncryptionTest, Encrypt) {
EXPECT_EQ(GetExpectedCryptByteBlock(), decrypt_config->crypt_byte_block());
EXPECT_EQ(GetExpectedSkipByteBlock(), decrypt_config->skip_byte_block());
std::vector<uint8_t> expected(
kData,
kData + sizeof(kData));
std::vector<uint8_t> actual(
media_sample->data(),
media_sample->data() + media_sample->data_size());
std::vector<uint8_t> expected(kData, kData + kDataSize);
std::vector<uint8_t> actual(media_sample->data(),
media_sample->data() + media_sample->data_size());
ASSERT_TRUE(Decrypt(*decrypt_config, actual.data(), actual.size()));
EXPECT_EQ(expected, actual);
}
// Verify that the data in short audio (less than leading clear bytes) is left
// unencrypted.
TEST_P(EncryptionHandlerEncryptionTest, SampleAesEncryptShortAudio) {
if (IsVideoCodec(codec_) ||
protection_scheme_ != kAppleSampleAesProtectionScheme) {
return;
}
EncryptionParams encryption_params;
encryption_params.protection_scheme = kAppleSampleAesProtectionScheme;
SetUpEncryptionHandler(encryption_params);
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(
DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK)));
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetAudioStreamInfo(kTimeScale, codec_))));
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kShortDataSize))));
ASSERT_EQ(2u, GetOutputStreamDataVector().size());
ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index);
ASSERT_EQ(StreamDataType::kMediaSample,
GetOutputStreamDataVector().back()->stream_data_type);
auto* media_sample = GetOutputStreamDataVector().back()->media_sample.get();
auto* decrypt_config = media_sample->decrypt_config();
EXPECT_TRUE(decrypt_config->subsamples().empty());
EXPECT_EQ(kAppleSampleAesProtectionScheme,
decrypt_config->protection_scheme());
std::vector<uint8_t> expected(kData, kData + kShortDataSize);
std::vector<uint8_t> actual(media_sample->data(),
media_sample->data() + media_sample->data_size());
EXPECT_EQ(expected, actual);
}
INSTANTIATE_TEST_CASE_P(
CencProtectionSchemes,
EncryptionHandlerEncryptionTest,

View File

@ -708,9 +708,17 @@ bool MP4MediaParser::EnqueueSample(bool* err) {
return false;
}
const uint8_t* media_data = buf;
const size_t media_data_size = runs_->sample_size();
// Use a dummy data size of 0 to avoid copying overhead.
// Actual media data is set later.
const size_t kDummyDataSize = 0;
std::shared_ptr<MediaSample> stream_sample(
MediaSample::CopyFrom(buf, runs_->sample_size(), runs_->is_keyframe()));
MediaSample::CopyFrom(media_data, kDummyDataSize, runs_->is_keyframe()));
if (runs_->is_encrypted()) {
std::shared_ptr<uint8_t> decrypted_media_data(
new uint8_t[media_data_size], std::default_delete<uint8_t[]>());
std::unique_ptr<DecryptConfig> decrypt_config = runs_->GetDecryptConfig();
if (!decrypt_config) {
*err = true;
@ -719,17 +727,24 @@ bool MP4MediaParser::EnqueueSample(bool* err) {
}
if (!decryptor_source_) {
stream_sample->SetData(media_data, media_data_size);
// If the demuxer does not have the decryptor_source_, store
// decrypt_config so that the demuxed sample can be decrypted later.
stream_sample->set_decrypt_config(std::move(decrypt_config));
stream_sample->set_is_encrypted(true);
} else if (!decryptor_source_->DecryptSampleBuffer(
decrypt_config.get(), stream_sample->writable_data(),
stream_sample->data_size())) {
*err = true;
LOG(ERROR) << "Cannot decrypt samples.";
return false;
} else {
if (!decryptor_source_->DecryptSampleBuffer(decrypt_config.get(),
media_data, media_data_size,
decrypted_media_data.get())) {
*err = true;
LOG(ERROR) << "Cannot decrypt samples.";
return false;
}
stream_sample->TransferData(std::move(decrypted_media_data),
media_data_size);
}
} else {
stream_sample->SetData(media_data, media_data_size);
}
stream_sample->set_dts(runs_->dts());

View File

@ -13,6 +13,57 @@
namespace shaka {
namespace media {
namespace webm {
namespace {
void WriteEncryptedFrameHeader(const DecryptConfig* decrypt_config,
BufferWriter* header_buffer) {
if (decrypt_config) {
const size_t iv_size = decrypt_config->iv().size();
DCHECK_EQ(iv_size, kWebMIvSize);
if (!decrypt_config->subsamples().empty()) {
const auto& subsamples = decrypt_config->subsamples();
// Use partitioned subsample encryption: | signal_byte(3) | iv
// | num_partitions | partition_offset * n | enc_data |
DCHECK_LT(subsamples.size(), kWebMMaxSubsamples);
const size_t num_partitions =
2 * subsamples.size() - 1 -
(subsamples.back().cipher_bytes == 0 ? 1 : 0);
const size_t header_size = kWebMSignalByteSize + iv_size +
kWebMNumPartitionsSize +
(kWebMPartitionOffsetSize * num_partitions);
const uint8_t signal_byte = kWebMEncryptedSignal | kWebMPartitionedSignal;
header_buffer->AppendInt(signal_byte);
header_buffer->AppendVector(decrypt_config->iv());
header_buffer->AppendInt(static_cast<uint8_t>(num_partitions));
uint32_t partition_offset = 0;
for (size_t i = 0; i < subsamples.size() - 1; ++i) {
partition_offset += subsamples[i].clear_bytes;
header_buffer->AppendInt(partition_offset);
partition_offset += subsamples[i].cipher_bytes;
header_buffer->AppendInt(partition_offset);
}
// Add another partition between the clear bytes and cipher bytes if
// cipher bytes is not zero.
if (subsamples.back().cipher_bytes != 0) {
partition_offset += subsamples.back().clear_bytes;
header_buffer->AppendInt(partition_offset);
}
DCHECK_EQ(header_size, header_buffer->Size());
} else {
// Use whole-frame encryption: | signal_byte(1) | iv | enc_data |
const uint8_t signal_byte = kWebMEncryptedSignal;
header_buffer->AppendInt(signal_byte);
header_buffer->AppendVector(decrypt_config->iv());
}
} else {
// Clear sample: | signal_byte(0) | data |
const uint8_t signal_byte = 0x00;
header_buffer->AppendInt(signal_byte);
}
}
} // namespace
Status UpdateTrackForEncryption(const std::vector<uint8_t>& key_id,
mkvmuxer::Track* track) {
@ -46,69 +97,17 @@ Status UpdateTrackForEncryption(const std::vector<uint8_t>& key_id,
}
void UpdateFrameForEncryption(MediaSample* sample) {
const size_t sample_size = sample->data_size();
if (sample->decrypt_config()) {
auto* decrypt_config = sample->decrypt_config();
const size_t iv_size = decrypt_config->iv().size();
DCHECK_EQ(iv_size, kWebMIvSize);
if (!decrypt_config->subsamples().empty()) {
auto& subsamples = decrypt_config->subsamples();
// Use partitioned subsample encryption: | signal_byte(3) | iv
// | num_partitions | partition_offset * n | enc_data |
DCHECK_LT(subsamples.size(), kWebMMaxSubsamples);
const size_t num_partitions =
2 * subsamples.size() - 1 -
(subsamples.back().cipher_bytes == 0 ? 1 : 0);
const size_t header_size = kWebMSignalByteSize + iv_size +
kWebMNumPartitionsSize +
(kWebMPartitionOffsetSize * num_partitions);
sample->resize_data(header_size + sample_size);
uint8_t* sample_data = sample->writable_data();
memmove(sample_data + header_size, sample_data, sample_size);
sample_data[0] = kWebMEncryptedSignal | kWebMPartitionedSignal;
memcpy(sample_data + kWebMSignalByteSize, decrypt_config->iv().data(),
iv_size);
sample_data[kWebMSignalByteSize + kWebMIvSize] =
static_cast<uint8_t>(num_partitions);
DCHECK(sample);
BufferWriter header_buffer;
WriteEncryptedFrameHeader(sample->decrypt_config(), &header_buffer);
BufferWriter offsets_buffer;
uint32_t partition_offset = 0;
for (size_t i = 0; i < subsamples.size() - 1; ++i) {
partition_offset += subsamples[i].clear_bytes;
offsets_buffer.AppendInt(partition_offset);
partition_offset += subsamples[i].cipher_bytes;
offsets_buffer.AppendInt(partition_offset);
}
// Add another partition between the clear bytes and cipher bytes if
// cipher bytes is not zero.
if (subsamples.back().cipher_bytes != 0) {
partition_offset += subsamples.back().clear_bytes;
offsets_buffer.AppendInt(partition_offset);
}
DCHECK_EQ(num_partitions * kWebMPartitionOffsetSize,
offsets_buffer.Size());
memcpy(sample_data + kWebMSignalByteSize + kWebMIvSize +
kWebMNumPartitionsSize,
offsets_buffer.Buffer(), offsets_buffer.Size());
} else {
// Use whole-frame encryption: | signal_byte(1) | iv | enc_data |
sample->resize_data(sample_size + iv_size + kWebMSignalByteSize);
uint8_t* sample_data = sample->writable_data();
// First move the sample data to after the IV; then write the IV and
// signal byte.
memmove(sample_data + iv_size + kWebMSignalByteSize, sample_data,
sample_size);
sample_data[0] = kWebMEncryptedSignal;
memcpy(sample_data + 1, decrypt_config->iv().data(), iv_size);
}
} else {
// Clear sample: | signal_byte(0) | data |
sample->resize_data(sample_size + 1);
uint8_t* sample_data = sample->writable_data();
memmove(sample_data + 1, sample_data, sample_size);
sample_data[0] = 0x00;
}
const size_t sample_size = header_buffer.Size() + sample->data_size();
std::shared_ptr<uint8_t> new_sample_data(new uint8_t[sample_size],
std::default_delete<uint8_t[]>());
memcpy(new_sample_data.get(), header_buffer.Buffer(), header_buffer.Size());
memcpy(&new_sample_data.get()[header_buffer.Size()], sample->data(),
sample->data_size());
sample->TransferData(std::move(new_sample_data), sample_size);
}
} // namespace webm

View File

@ -376,21 +376,34 @@ bool WebMClusterParser::OnBlock(bool is_simple_block,
return false;
}
buffer = MediaSample::CopyFrom(data + data_offset, size - data_offset,
additional, additional_size, is_key_frame);
const uint8_t* media_data = data + data_offset;
const size_t media_data_size = size - data_offset;
// Use a dummy data size of 0 to avoid copying overhead.
// Actual media data is set later.
const size_t kDummyDataSize = 0;
buffer = MediaSample::CopyFrom(media_data, kDummyDataSize, additional,
additional_size, is_key_frame);
if (decrypt_config) {
if (!decryptor_source_) {
buffer->SetData(media_data, media_data_size);
// If the demuxer does not have the decryptor_source_, store
// decrypt_config so that the demuxed sample can be decrypted later.
buffer->set_decrypt_config(std::move(decrypt_config));
buffer->set_is_encrypted(true);
} else if (!decryptor_source_->DecryptSampleBuffer(
decrypt_config.get(), buffer->writable_data(),
buffer->data_size())) {
LOG(ERROR) << "Cannot decrypt samples";
return false;
} else {
std::shared_ptr<uint8_t> decrypted_media_data(
new uint8_t[media_data_size], std::default_delete<uint8_t[]>());
if (!decryptor_source_->DecryptSampleBuffer(
decrypt_config.get(), media_data, media_data_size,
decrypted_media_data.get())) {
LOG(ERROR) << "Cannot decrypt samples";
return false;
}
buffer->TransferData(std::move(decrypted_media_data), media_data_size);
}
} else {
buffer->SetData(media_data, media_data_size);
}
} else {
std::string id, settings, content;

View File

@ -816,7 +816,7 @@ void WvmMediaParser::StartMediaSampleDemux() {
bool WvmMediaParser::Output(bool output_encrypted_sample) {
if (output_encrypted_sample) {
media_sample_->set_data(sample_data_.data(), sample_data_.size());
media_sample_->SetData(sample_data_.data(), sample_data_.size());
media_sample_->set_is_encrypted(true);
} else {
if ((prev_pes_stream_id_ & kPesStreamIdVideoMask) == kPesStreamIdVideo) {
@ -827,7 +827,7 @@ bool WvmMediaParser::Output(bool output_encrypted_sample) {
LOG(ERROR) << "Could not convert h.264 byte stream sample";
return false;
}
media_sample_->set_data(nal_unit_stream.data(), nal_unit_stream.size());
media_sample_->SetData(nal_unit_stream.data(), nal_unit_stream.size());
if (!is_initialized_) {
// Set extra data for video stream from AVC Decoder Config Record.
// Also, set codec string from the AVC Decoder Config Record.
@ -914,8 +914,7 @@ bool WvmMediaParser::Output(bool output_encrypted_sample) {
}
size_t header_size = adts_header.GetAdtsHeaderSize(frame_ptr,
frame_size);
media_sample_->set_data(frame_ptr + header_size,
frame_size - header_size);
media_sample_->SetData(frame_ptr + header_size, frame_size - header_size);
if (!is_initialized_) {
for (uint32_t i = 0; i < stream_infos_.size(); i++) {
if (stream_infos_[i]->stream_type() == kStreamAudio &&