Update MediaSample to avoid redundent copying

Use std::shared_ptr<const uint8_t> with a custom deleter to
represent MediaSample::data_ instead of std::vector<uint8_t>.

MediaSample::data_ can be shared by multiple MediaSamples and it is
immutable. A new data instance must be created if the clients want to
modify the underlying data. The new data instance can be transferred
to MediaSample using provided MediaSample::TransferData function.
This avoids unnecessary data copying.

Change-Id: Ib59785a9e19d0abb3283179b12eb6779ee922f79
This commit is contained in:
KongQun Yang 2017-09-18 16:31:00 -07:00
parent 6c0f2bebef
commit 92e1e39868
14 changed files with 402 additions and 282 deletions

View File

@ -10,6 +10,18 @@
#include "packager/media/base/aes_decryptor.h" #include "packager/media/base/aes_decryptor.h"
#include "packager/media/base/aes_pattern_cryptor.h" #include "packager/media/base/aes_pattern_cryptor.h"
namespace {
// Return true if [encrypted_buffer, encrypted_buffer + buffer_size) overlaps
// with [decrypted_buffer, decrypted_buffer + buffer_size).
bool CheckMemoryOverlap(const uint8_t* encrypted_buffer,
size_t buffer_size,
uint8_t* decrypted_buffer) {
return (decrypted_buffer < encrypted_buffer)
? (encrypted_buffer < decrypted_buffer + buffer_size)
: (decrypted_buffer < encrypted_buffer + buffer_size);
}
} // namespace
namespace shaka { namespace shaka {
namespace media { namespace media {
@ -20,10 +32,17 @@ DecryptorSource::DecryptorSource(KeySource* key_source)
DecryptorSource::~DecryptorSource() {} DecryptorSource::~DecryptorSource() {}
bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config, bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config,
uint8_t* buffer, const uint8_t* encrypted_buffer,
size_t buffer_size) { size_t buffer_size,
uint8_t* decrypted_buffer) {
DCHECK(decrypt_config); DCHECK(decrypt_config);
DCHECK(buffer); DCHECK(encrypted_buffer);
DCHECK(decrypted_buffer);
if (CheckMemoryOverlap(encrypted_buffer, buffer_size, decrypted_buffer)) {
LOG(ERROR) << "Encrypted buffer and decrypted buffer cannot overlap.";
return false;
}
// Get the decryptor object. // Get the decryptor object.
AesCryptor* decryptor = nullptr; AesCryptor* decryptor = nullptr;
@ -83,7 +102,7 @@ bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config,
if (decrypt_config->subsamples().empty()) { if (decrypt_config->subsamples().empty()) {
// Sample not encrypted using subsample encryption. Decrypt whole. // Sample not encrypted using subsample encryption. Decrypt whole.
if (!decryptor->Crypt(buffer, buffer_size, buffer)) { if (!decryptor->Crypt(encrypted_buffer, buffer_size, decrypted_buffer)) {
LOG(ERROR) << "Error during bulk sample decryption."; LOG(ERROR) << "Error during bulk sample decryption.";
return false; return false;
} }
@ -92,20 +111,24 @@ bool DecryptorSource::DecryptSampleBuffer(const DecryptConfig* decrypt_config,
// Subsample decryption. // Subsample decryption.
const std::vector<SubsampleEntry>& subsamples = decrypt_config->subsamples(); const std::vector<SubsampleEntry>& subsamples = decrypt_config->subsamples();
uint8_t* current_ptr = buffer; const uint8_t* current_ptr = encrypted_buffer;
const uint8_t* const buffer_end = buffer + buffer_size; const uint8_t* const buffer_end = encrypted_buffer + buffer_size;
for (const auto& subsample : subsamples) { for (const auto& subsample : subsamples) {
if ((current_ptr + subsample.clear_bytes + subsample.cipher_bytes) > if ((current_ptr + subsample.clear_bytes + subsample.cipher_bytes) >
buffer_end) { buffer_end) {
LOG(ERROR) << "Subsamples overflow sample buffer."; LOG(ERROR) << "Subsamples overflow sample buffer.";
return false; return false;
} }
memcpy(decrypted_buffer, current_ptr, subsample.clear_bytes);
current_ptr += subsample.clear_bytes; current_ptr += subsample.clear_bytes;
if (!decryptor->Crypt(current_ptr, subsample.cipher_bytes, current_ptr)) { decrypted_buffer += subsample.clear_bytes;
if (!decryptor->Crypt(current_ptr, subsample.cipher_bytes,
decrypted_buffer)) {
LOG(ERROR) << "Error decrypting subsample buffer."; LOG(ERROR) << "Error decrypting subsample buffer.";
return false; return false;
} }
current_ptr += subsample.cipher_bytes; current_ptr += subsample.cipher_bytes;
decrypted_buffer += subsample.cipher_bytes;
} }
return true; return true;
} }

View File

@ -21,12 +21,24 @@ namespace media {
/// DecryptorSource wraps KeySource and is responsible for decryptor management. /// DecryptorSource wraps KeySource and is responsible for decryptor management.
class DecryptorSource { class DecryptorSource {
public: public:
/// Constructs a DecryptorSource object.
/// @param key_source points to the key source that contains the keys.
explicit DecryptorSource(KeySource* key_source); explicit DecryptorSource(KeySource* key_source);
~DecryptorSource(); ~DecryptorSource();
/// Decrypt encrypted buffer.
/// @param decrypt_config contains decrypt configuration, e.g. protection
/// scheme, subsample information etc.
/// @param encrypted_buffer points to the encrypted buffer that is to be
/// decrypted. It should not overlap with @a decrypted_buffer.
/// @param buffer_size is the size of encrypted buffer and decrypted buffer.
/// @param decrypted_buffer points to the decrypted buffer. It should not
/// overlap with @a encrypted_buffer.
/// @return true if success, false otherwise.
bool DecryptSampleBuffer(const DecryptConfig* decrypt_config, bool DecryptSampleBuffer(const DecryptConfig* decrypt_config,
uint8_t* buffer, const uint8_t* encrypted_buffer,
size_t buffer_size); size_t buffer_size,
uint8_t* decrypted_buffer);
private: private:
KeySource* key_source_; KeySource* key_source_;

View File

@ -64,13 +64,15 @@ class DecryptorSourceTest : public ::testing::Test {
DecryptorSourceTest() DecryptorSourceTest()
: decryptor_source_(&mock_key_source_), : decryptor_source_(&mock_key_source_),
key_id_(std::vector<uint8_t>(kKeyId, kKeyId + arraysize(kKeyId))), key_id_(std::vector<uint8_t>(kKeyId, kKeyId + arraysize(kKeyId))),
buffer_(std::vector<uint8_t>(kBuffer, kBuffer + arraysize(kBuffer))) {} encrypted_buffer_(kBuffer, kBuffer + arraysize(kBuffer)),
decrypted_buffer_(arraysize(kBuffer)) {}
protected: protected:
StrictMock<MockKeySource> mock_key_source_; StrictMock<MockKeySource> mock_key_source_;
DecryptorSource decryptor_source_; DecryptorSource decryptor_source_;
std::vector<uint8_t> key_id_; std::vector<uint8_t> key_id_;
std::vector<uint8_t> buffer_; std::vector<uint8_t> encrypted_buffer_;
std::vector<uint8_t> decrypted_buffer_;
}; };
TEST_F(DecryptorSourceTest, FullSampleDecryption) { TEST_F(DecryptorSourceTest, FullSampleDecryption) {
@ -83,24 +85,27 @@ TEST_F(DecryptorSourceTest, FullSampleDecryption) {
std::vector<uint8_t>(kIv, kIv + arraysize(kIv)), std::vector<uint8_t>(kIv, kIv + arraysize(kIv)),
std::vector<SubsampleEntry>()); std::vector<SubsampleEntry>());
ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer( ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size())); &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
EXPECT_EQ(std::vector<uint8_t>( EXPECT_EQ(std::vector<uint8_t>(
kExpectedDecryptedBuffer, kExpectedDecryptedBuffer,
kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)), kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)),
buffer_); decrypted_buffer_);
// DecryptSampleBuffer can be called repetitively. No GetKey call again with // DecryptSampleBuffer can be called repetitively. No GetKey call again with
// the same key id. // the same key id.
buffer_.assign(kBuffer2, kBuffer2 + arraysize(kBuffer2)); encrypted_buffer_.assign(kBuffer2, kBuffer2 + arraysize(kBuffer2));
decrypted_buffer_.resize(arraysize(kBuffer2));
DecryptConfig decrypt_config2( DecryptConfig decrypt_config2(
key_id_, std::vector<uint8_t>(kIv2, kIv2 + arraysize(kIv2)), key_id_, std::vector<uint8_t>(kIv2, kIv2 + arraysize(kIv2)),
std::vector<SubsampleEntry>()); std::vector<SubsampleEntry>());
ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer( ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config2, &buffer_[0], buffer_.size())); &decrypt_config2, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
EXPECT_EQ(std::vector<uint8_t>(kExpectedDecryptedBuffer2, EXPECT_EQ(std::vector<uint8_t>(kExpectedDecryptedBuffer2,
kExpectedDecryptedBuffer2 + kExpectedDecryptedBuffer2 +
arraysize(kExpectedDecryptedBuffer2)), arraysize(kExpectedDecryptedBuffer2)),
buffer_); decrypted_buffer_);
} }
TEST_F(DecryptorSourceTest, SubsampleDecryption) { TEST_F(DecryptorSourceTest, SubsampleDecryption) {
@ -131,11 +136,12 @@ TEST_F(DecryptorSourceTest, SubsampleDecryption) {
std::vector<SubsampleEntry>(kSubsamples, std::vector<SubsampleEntry>(kSubsamples,
kSubsamples + arraysize(kSubsamples))); kSubsamples + arraysize(kSubsamples)));
ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer( ASSERT_TRUE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size())); &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
EXPECT_EQ(std::vector<uint8_t>( EXPECT_EQ(std::vector<uint8_t>(
kExpectedDecryptedBuffer, kExpectedDecryptedBuffer,
kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)), kExpectedDecryptedBuffer + arraysize(kExpectedDecryptedBuffer)),
buffer_); decrypted_buffer_);
} }
TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) { TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) {
@ -155,7 +161,8 @@ TEST_F(DecryptorSourceTest, SubsampleDecryptionSizeValidation) {
std::vector<SubsampleEntry>(kSubsamples, std::vector<SubsampleEntry>(kSubsamples,
kSubsamples + arraysize(kSubsamples))); kSubsamples + arraysize(kSubsamples)));
ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer( ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size())); &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
} }
TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) { TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) {
@ -166,7 +173,17 @@ TEST_F(DecryptorSourceTest, DecryptFailedIfGetKeyFailed) {
std::vector<uint8_t>(kIv, kIv + arraysize(kIv)), std::vector<uint8_t>(kIv, kIv + arraysize(kIv)),
std::vector<SubsampleEntry>()); std::vector<SubsampleEntry>());
ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer( ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &buffer_[0], buffer_.size())); &decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&decrypted_buffer_[0]));
}
TEST_F(DecryptorSourceTest, EncryptedBufferAndDecryptedBufferOverlap) {
DecryptConfig decrypt_config(key_id_,
std::vector<uint8_t>(kIv, kIv + arraysize(kIv)),
std::vector<SubsampleEntry>());
ASSERT_FALSE(decryptor_source_.DecryptSampleBuffer(
&decrypt_config, &encrypted_buffer_[0], encrypted_buffer_.size(),
&encrypted_buffer_[5]));
} }
} // namespace media } // namespace media

View File

@ -154,21 +154,21 @@ std::unique_ptr<StreamInfo> MediaHandlerTestBase::GetAudioStreamInfo(
!kEncrypted)); !kEncrypted));
} }
std::unique_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample( std::shared_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample(
int64_t timestamp, int64_t timestamp,
int64_t duration, int64_t duration,
bool is_keyframe) const { bool is_keyframe) const {
return GetMediaSample(timestamp, duration, is_keyframe, kData, sizeof(kData)); return GetMediaSample(timestamp, duration, is_keyframe, kData, sizeof(kData));
} }
std::unique_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample( std::shared_ptr<MediaSample> MediaHandlerTestBase::GetMediaSample(
int64_t timestamp, int64_t timestamp,
int64_t duration, int64_t duration,
bool is_keyframe, bool is_keyframe,
const uint8_t* data, const uint8_t* data,
size_t data_length) const { size_t data_length) const {
std::unique_ptr<MediaSample> sample( std::shared_ptr<MediaSample> sample =
new MediaSample(data, data_length, nullptr, 0, is_keyframe)); MediaSample::CopyFrom(data, data_length, nullptr, 0, is_keyframe);
sample->set_dts(timestamp); sample->set_dts(timestamp);
sample->set_duration(duration); sample->set_duration(duration);

View File

@ -150,11 +150,11 @@ class MediaHandlerTestBase : public ::testing::Test {
std::unique_ptr<StreamInfo> GetAudioStreamInfo(uint32_t time_scale, std::unique_ptr<StreamInfo> GetAudioStreamInfo(uint32_t time_scale,
Codec codec) const; Codec codec) const;
std::unique_ptr<MediaSample> GetMediaSample(int64_t timestamp, std::shared_ptr<MediaSample> GetMediaSample(int64_t timestamp,
int64_t duration, int64_t duration,
bool is_keyframe) const; bool is_keyframe) const;
std::unique_ptr<MediaSample> GetMediaSample(int64_t timestamp, std::shared_ptr<MediaSample> GetMediaSample(int64_t timestamp,
int64_t duration, int64_t duration,
bool is_keyframe, bool is_keyframe,
const uint8_t* data, const uint8_t* data,

View File

@ -15,29 +15,26 @@ namespace shaka {
namespace media { namespace media {
MediaSample::MediaSample(const uint8_t* data, MediaSample::MediaSample(const uint8_t* data,
size_t size, size_t data_size,
const uint8_t* side_data, const uint8_t* side_data,
size_t side_data_size, size_t side_data_size,
bool is_key_frame) bool is_key_frame)
: dts_(0), : is_key_frame_(is_key_frame) {
pts_(0),
duration_(0),
is_key_frame_(is_key_frame),
is_encrypted_(false) {
if (!data) { if (!data) {
CHECK_EQ(size, 0u); CHECK_EQ(data_size, 0u);
} }
data_.assign(data, data + size); SetData(data, data_size);
if (side_data) if (side_data) {
side_data_.assign(side_data, side_data + side_data_size); std::shared_ptr<uint8_t> shared_side_data(new uint8_t[side_data_size],
std::default_delete<uint8_t[]>());
memcpy(shared_side_data.get(), side_data, side_data_size);
side_data_ = std::move(shared_side_data);
side_data_size_ = side_data_size;
}
} }
MediaSample::MediaSample() : dts_(0), MediaSample::MediaSample() {}
pts_(0),
duration_(0),
is_key_frame_(false),
is_encrypted_(false) {}
MediaSample::~MediaSample() {} MediaSample::~MediaSample() {}
@ -47,8 +44,8 @@ std::shared_ptr<MediaSample> MediaSample::CopyFrom(const uint8_t* data,
bool is_key_frame) { bool is_key_frame) {
// If you hit this CHECK you likely have a bug in a demuxer. Go fix it. // If you hit this CHECK you likely have a bug in a demuxer. Go fix it.
CHECK(data); CHECK(data);
return std::make_shared<MediaSample>(data, data_size, nullptr, 0u, return std::shared_ptr<MediaSample>(
is_key_frame); new MediaSample(data, data_size, nullptr, 0u, is_key_frame));
} }
// static // static
@ -59,32 +56,32 @@ std::shared_ptr<MediaSample> MediaSample::CopyFrom(const uint8_t* data,
bool is_key_frame) { bool is_key_frame) {
// If you hit this CHECK you likely have a bug in a demuxer. Go fix it. // If you hit this CHECK you likely have a bug in a demuxer. Go fix it.
CHECK(data); CHECK(data);
return std::make_shared<MediaSample>(data, data_size, side_data, return std::shared_ptr<MediaSample>(new MediaSample(
side_data_size, is_key_frame); data, data_size, side_data, side_data_size, is_key_frame));
} }
// static // static
std::shared_ptr<MediaSample> MediaSample::CopyFrom( std::shared_ptr<MediaSample> MediaSample::CopyFrom(
const MediaSample& media_sample) { const MediaSample& media_sample) {
std::shared_ptr<MediaSample> new_media_sample = CopyFrom( std::shared_ptr<MediaSample> new_media_sample(new MediaSample);
media_sample.data(), media_sample.data_size(), media_sample.side_data(), new_media_sample->dts_ = media_sample.dts_;
media_sample.side_data_size(), media_sample.is_key_frame()); new_media_sample->pts_ = media_sample.pts_;
new_media_sample->duration_ = media_sample.duration_;
new_media_sample->set_dts(media_sample.dts()); new_media_sample->is_key_frame_ = media_sample.is_key_frame_;
new_media_sample->set_pts(media_sample.pts()); new_media_sample->is_encrypted_ = media_sample.is_encrypted_;
new_media_sample->set_is_encrypted(media_sample.is_encrypted()); new_media_sample->data_ = media_sample.data_;
new_media_sample->set_config_id(media_sample.config_id()); new_media_sample->data_size_ = media_sample.data_size_;
new_media_sample->set_duration(media_sample.duration()); new_media_sample->side_data_ = media_sample.side_data_;
new_media_sample->side_data_size_ = media_sample.side_data_size_;
if (media_sample.decrypt_config()) { new_media_sample->config_id_ = media_sample.config_id_;
std::unique_ptr<DecryptConfig> decrypt_config( if (media_sample.decrypt_config_) {
new DecryptConfig(media_sample.decrypt_config()->key_id(), new_media_sample->decrypt_config_.reset(
media_sample.decrypt_config()->iv(), new DecryptConfig(media_sample.decrypt_config_->key_id(),
media_sample.decrypt_config()->subsamples(), media_sample.decrypt_config_->iv(),
media_sample.decrypt_config()->protection_scheme(), media_sample.decrypt_config_->subsamples(),
media_sample.decrypt_config()->crypt_byte_block(), media_sample.decrypt_config_->protection_scheme(),
media_sample.decrypt_config()->skip_byte_block())); media_sample.decrypt_config_->crypt_byte_block(),
new_media_sample->set_decrypt_config(std::move(decrypt_config)); media_sample.decrypt_config_->skip_byte_block()));
} }
return new_media_sample; return new_media_sample;
} }
@ -92,32 +89,43 @@ std::shared_ptr<MediaSample> MediaSample::CopyFrom(
// static // static
std::shared_ptr<MediaSample> MediaSample::FromMetadata(const uint8_t* metadata, std::shared_ptr<MediaSample> MediaSample::FromMetadata(const uint8_t* metadata,
size_t metadata_size) { size_t metadata_size) {
return std::make_shared<MediaSample>(nullptr, 0, metadata, metadata_size, return std::shared_ptr<MediaSample>(
false); new MediaSample(nullptr, 0, metadata, metadata_size, false));
} }
// static // static
std::shared_ptr<MediaSample> MediaSample::CreateEmptyMediaSample() { std::shared_ptr<MediaSample> MediaSample::CreateEmptyMediaSample() {
return std::make_shared<MediaSample>(); return std::shared_ptr<MediaSample>(new MediaSample);
} }
// static // static
std::shared_ptr<MediaSample> MediaSample::CreateEOSBuffer() { std::shared_ptr<MediaSample> MediaSample::CreateEOSBuffer() {
return std::make_shared<MediaSample>(nullptr, 0, nullptr, 0, false); return std::shared_ptr<MediaSample>(
new MediaSample(nullptr, 0, nullptr, 0, false));
}
void MediaSample::TransferData(std::shared_ptr<uint8_t> data,
size_t data_size) {
data_ = std::move(data);
data_size_ = data_size;
}
void MediaSample::SetData(const uint8_t* data, size_t data_size) {
std::shared_ptr<uint8_t> shared_data(new uint8_t[data_size],
std::default_delete<uint8_t[]>());
memcpy(shared_data.get(), data, data_size);
TransferData(std::move(shared_data), data_size);
} }
std::string MediaSample::ToString() const { std::string MediaSample::ToString() const {
if (end_of_stream()) if (end_of_stream())
return "End of stream sample\n"; return "End of stream sample\n";
return base::StringPrintf( return base::StringPrintf(
"dts: %" PRId64 "\n pts: %" PRId64 "\n duration: %" PRId64 "\n " "dts: %" PRId64 "\n pts: %" PRId64 "\n duration: %" PRId64
"\n "
"is_key_frame: %s\n size: %zu\n side_data_size: %zu\n", "is_key_frame: %s\n size: %zu\n side_data_size: %zu\n",
dts_, dts_, pts_, duration_, is_key_frame_ ? "true" : "false", data_size_,
pts_, side_data_size_);
duration_,
is_key_frame_ ? "true" : "false",
data_.size(),
side_data_.size());
} }
} // namespace media } // namespace media

View File

@ -66,17 +66,22 @@ class MediaSample {
/// is disallowed. /// is disallowed.
static std::shared_ptr<MediaSample> CreateEOSBuffer(); static std::shared_ptr<MediaSample> CreateEOSBuffer();
// Create a MediaSample. Buffer will be padded and aligned as necessary.
// |data|,|side_data| can be NULL, which indicates an empty sample.
// |size|,|side_data_size| should not be negative.
MediaSample(const uint8_t* data,
size_t size,
const uint8_t* side_data,
size_t side_data_size,
bool is_key_frame);
MediaSample();
virtual ~MediaSample(); virtual ~MediaSample();
/// Transfer data to this media sample. No data copying is involved.
/// @param data points to the data to be transferred.
/// @param data_size is the size of the data to be transferred.
void TransferData(std::shared_ptr<uint8_t> data, size_t data_size);
/// Set the data in this media sample. Note that this method involves data
/// copying.
/// @param data points to the data to be copied.
/// @param data_size is the size of the data to be copied.
void SetData(const uint8_t* data, size_t data_size);
/// @return a human-readable string describing |*this|.
std::string ToString() const;
int64_t dts() const { int64_t dts() const {
DCHECK(!end_of_stream()); DCHECK(!end_of_stream());
return dts_; return dts_;
@ -112,38 +117,19 @@ class MediaSample {
} }
const uint8_t* data() const { const uint8_t* data() const {
DCHECK(!end_of_stream()); DCHECK(!end_of_stream());
return data_.data(); return data_.get();
}
uint8_t* writable_data() {
DCHECK(!end_of_stream());
return data_.data();
} }
size_t data_size() const { size_t data_size() const {
DCHECK(!end_of_stream()); DCHECK(!end_of_stream());
return data_.size(); return data_size_;
} }
const uint8_t* side_data() const { const uint8_t* side_data() const { return side_data_.get(); }
return side_data_.data();
}
size_t side_data_size() const { size_t side_data_size() const { return side_data_size_; }
return side_data_.size();
}
const DecryptConfig* decrypt_config() const { const DecryptConfig* decrypt_config() const { return decrypt_config_.get(); }
return decrypt_config_.get();
}
void set_data(const uint8_t* data, const size_t data_size) {
data_.assign(data, data + data_size);
}
void resize_data(const size_t data_size) {
data_.resize(data_size);
}
void set_is_key_frame(bool value) { void set_is_key_frame(bool value) {
is_key_frame_ = value; is_key_frame_ = value;
@ -158,32 +144,42 @@ class MediaSample {
} }
// If there's no data in this buffer, it represents end of stream. // If there's no data in this buffer, it represents end of stream.
bool end_of_stream() const { return data_.size() == 0; } bool end_of_stream() const { return data_size_ == 0; }
const std::string& config_id() const { return config_id_; } const std::string& config_id() const { return config_id_; }
void set_config_id(const std::string& config_id) { void set_config_id(const std::string& config_id) {
config_id_ = config_id; config_id_ = config_id;
} }
/// @return a human-readable string describing |*this|. protected:
std::string ToString() const; // Made it protected to disallow the constructor to be called directly.
// Create a MediaSample. Buffer will be padded and aligned as necessary.
// |data|,|side_data| can be nullptr, which indicates an empty sample.
MediaSample(const uint8_t* data,
size_t data_size,
const uint8_t* side_data,
size_t side_data_size,
bool is_key_frame);
MediaSample();
private: private:
// Decoding time stamp. // Decoding time stamp.
int64_t dts_; int64_t dts_ = 0;
// Presentation time stamp. // Presentation time stamp.
int64_t pts_; int64_t pts_ = 0;
int64_t duration_; int64_t duration_ = 0;
bool is_key_frame_; bool is_key_frame_ = false;
// is sample encrypted ? // is sample encrypted ?
bool is_encrypted_; bool is_encrypted_ = false;
// Main buffer data. // Main buffer data.
std::vector<uint8_t> data_; std::shared_ptr<const uint8_t> data_;
size_t data_size_ = 0;
// Contain additional buffers to complete the main one. Needed by WebM // Contain additional buffers to complete the main one. Needed by WebM
// http://www.matroska.org/technical/specs/index.html BlockAdditional[A5]. // http://www.matroska.org/technical/specs/index.html BlockAdditional[A5].
// Not used by mp4 and other containers. // Not used by mp4 and other containers.
std::vector<uint8_t> side_data_; std::shared_ptr<const uint8_t> side_data_;
size_t side_data_size_ = 0;
// Text specific fields. // Text specific fields.
// For now this is the cue identifier for WebVTT. // For now this is the cue identifier for WebVTT.

View File

@ -9,6 +9,7 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <algorithm>
#include <limits> #include <limits>
#include "packager/media/base/aes_encryptor.h" #include "packager/media/base/aes_encryptor.h"
@ -259,49 +260,47 @@ Status EncryptionHandler::ProcessMediaSample(
// in-place. // in-place.
std::shared_ptr<MediaSample> cipher_sample = std::shared_ptr<MediaSample> cipher_sample =
MediaSample::CopyFrom(*clear_sample); MediaSample::CopyFrom(*clear_sample);
// |cipher_sample| above still contains the old clear sample data. We will
// use |cipher_sample_data| to hold cipher sample data then transfer it to
// |cipher_sample| after encryption.
std::shared_ptr<uint8_t> cipher_sample_data(
new uint8_t[clear_sample->data_size()], std::default_delete<uint8_t[]>());
Status result;
if (vpx_parser_) { if (vpx_parser_) {
if (EncryptVpxFrame(vpx_frames, if (!EncryptVpxFrame(vpx_frames, clear_sample->data(),
cipher_sample->writable_data(), clear_sample->data_size(),
cipher_sample->data_size(), &cipher_sample_data.get()[0], decrypt_config.get())) {
decrypt_config.get())) { return Status(error::ENCRYPTION_FAILURE, "Failed to encrypt VPX frame.");
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
cipher_sample->data_size());
} else {
result = Status(
error::ENCRYPTION_FAILURE,
"Failed to encrypt VPX frame.");
} }
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
clear_sample->data_size());
} else if (header_parser_) { } else if (header_parser_) {
if (EncryptNalFrame(cipher_sample->writable_data(), if (!EncryptNalFrame(clear_sample->data(), clear_sample->data_size(),
cipher_sample->data_size(), &cipher_sample_data.get()[0], decrypt_config.get())) {
decrypt_config.get())) { return Status(error::ENCRYPTION_FAILURE, "Failed to encrypt NAL frame.");
}
DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(), DCHECK_EQ(decrypt_config->GetTotalSizeOfSubsamples(),
cipher_sample->data_size()); clear_sample->data_size());
} else { } else {
result = Status( memcpy(&cipher_sample_data.get()[0], clear_sample->data(),
error::ENCRYPTION_FAILURE, std::min(clear_sample->data_size(), leading_clear_bytes_size_));
"Failed to encrypt NAL frame."); if (clear_sample->data_size() > leading_clear_bytes_size_) {
EncryptBytes(clear_sample->data() + leading_clear_bytes_size_,
clear_sample->data_size() - leading_clear_bytes_size_,
&cipher_sample_data.get()[leading_clear_bytes_size_]);
} }
} else if (cipher_sample->data_size() > leading_clear_bytes_size_) {
EncryptBytes(
cipher_sample->writable_data() + leading_clear_bytes_size_,
cipher_sample->data_size() - leading_clear_bytes_size_);
} }
if (!result.ok()) { cipher_sample->TransferData(std::move(cipher_sample_data),
return result; clear_sample->data_size());
}
encryptor_->UpdateIv();
// Finish initializing the sample before sending it downstream. We must // Finish initializing the sample before sending it downstream. We must
// wait until now to finish the initialization as we will loose access to // wait until now to finish the initialization as we will lose access to
// |decrypt_config| once we set it. // |decrypt_config| once we set it.
cipher_sample->set_is_encrypted(true); cipher_sample->set_is_encrypted(true);
cipher_sample->set_decrypt_config(std::move(decrypt_config)); cipher_sample->set_decrypt_config(std::move(decrypt_config));
encryptor_->UpdateIv();
return DispatchMediaSample(kStreamIndex, std::move(cipher_sample)); return DispatchMediaSample(kStreamIndex, std::move(cipher_sample));
} }
@ -433,10 +432,11 @@ bool EncryptionHandler::CreateEncryptor(const EncryptionKey& encryption_key) {
bool EncryptionHandler::EncryptVpxFrame( bool EncryptionHandler::EncryptVpxFrame(
const std::vector<VPxFrameInfo>& vpx_frames, const std::vector<VPxFrameInfo>& vpx_frames,
uint8_t* source, const uint8_t* source,
size_t source_size, size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config) { DecryptConfig* decrypt_config) {
uint8_t* data = source; const uint8_t* data = source;
for (const VPxFrameInfo& frame : vpx_frames) { for (const VPxFrameInfo& frame : vpx_frames) {
uint16_t clear_bytes = uint16_t clear_bytes =
static_cast<uint16_t>(frame.uncompressed_header_size); static_cast<uint16_t>(frame.uncompressed_header_size);
@ -456,9 +456,11 @@ bool EncryptionHandler::EncryptVpxFrame(
cipher_bytes -= misalign_bytes; cipher_bytes -= misalign_bytes;
decrypt_config->AddSubsample(clear_bytes, cipher_bytes); decrypt_config->AddSubsample(clear_bytes, cipher_bytes);
memcpy(dest, data, clear_bytes);
if (cipher_bytes > 0) if (cipher_bytes > 0)
EncryptBytes(data + clear_bytes, cipher_bytes); EncryptBytes(data + clear_bytes, cipher_bytes, dest + clear_bytes);
data += frame.frame_size; data += frame.frame_size;
dest += frame.frame_size;
} }
// Add subsample for the superframe index if exists. // Add subsample for the superframe index if exists.
const bool is_superframe = vpx_frames.size() > 1; const bool is_superframe = vpx_frames.size() > 1;
@ -469,18 +471,20 @@ bool EncryptionHandler::EncryptVpxFrame(
uint16_t clear_bytes = static_cast<uint16_t>(index_size); uint16_t clear_bytes = static_cast<uint16_t>(index_size);
uint32_t cipher_bytes = 0; uint32_t cipher_bytes = 0;
decrypt_config->AddSubsample(clear_bytes, cipher_bytes); decrypt_config->AddSubsample(clear_bytes, cipher_bytes);
memcpy(dest, data, clear_bytes);
} }
return true; return true;
} }
bool EncryptionHandler::EncryptNalFrame(uint8_t* data, bool EncryptionHandler::EncryptNalFrame(const uint8_t* source,
size_t data_length, size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config) { DecryptConfig* decrypt_config) {
DCHECK_NE(nalu_length_size_, 0u); DCHECK_NE(nalu_length_size_, 0u);
DCHECK(header_parser_); DCHECK(header_parser_);
const Nalu::CodecType nalu_type = const Nalu::CodecType nalu_type =
(codec_ == kCodecH265) ? Nalu::kH265 : Nalu::kH264; (codec_ == kCodecH265) ? Nalu::kH265 : Nalu::kH264;
NaluReader reader(nalu_type, nalu_length_size_, data, data_length); NaluReader reader(nalu_type, nalu_length_size_, source, source_size);
// Store the current length of clear data. This is used to squash // Store the current length of clear data. This is used to squash
// multiple unencrypted NAL units into fewer subsample entries. // multiple unencrypted NAL units into fewer subsample entries.
@ -519,13 +523,17 @@ bool EncryptionHandler::EncryptNalFrame(uint8_t* data,
cipher_bytes -= misalign_bytes; cipher_bytes -= misalign_bytes;
} }
const uint8_t* nalu_data = nalu.data() + current_clear_bytes; accumulated_clear_bytes += nalu_length_size_ + current_clear_bytes;
EncryptBytes(const_cast<uint8_t*>(nalu_data), cipher_bytes); AddSubsample(accumulated_clear_bytes, cipher_bytes, decrypt_config);
memcpy(dest, source, accumulated_clear_bytes);
AddSubsample( source += accumulated_clear_bytes;
accumulated_clear_bytes + nalu_length_size_ + current_clear_bytes, dest += accumulated_clear_bytes;
cipher_bytes, decrypt_config);
accumulated_clear_bytes = 0; accumulated_clear_bytes = 0;
DCHECK_EQ(nalu.data() + current_clear_bytes, source);
EncryptBytes(source, cipher_bytes, dest);
source += cipher_bytes;
dest += cipher_bytes;
} else { } else {
// For non-video-slice or small NAL units, don't encrypt. // For non-video-slice or small NAL units, don't encrypt.
accumulated_clear_bytes += nalu_length_size_ + nalu_total_size; accumulated_clear_bytes += nalu_length_size_ + nalu_total_size;
@ -536,13 +544,17 @@ bool EncryptionHandler::EncryptNalFrame(uint8_t* data,
return false; return false;
} }
AddSubsample(accumulated_clear_bytes, 0, decrypt_config); AddSubsample(accumulated_clear_bytes, 0, decrypt_config);
memcpy(dest, source, accumulated_clear_bytes);
return true; return true;
} }
void EncryptionHandler::EncryptBytes(uint8_t* data, size_t size) { void EncryptionHandler::EncryptBytes(const uint8_t* source,
DCHECK(data); size_t source_size,
uint8_t* dest) {
DCHECK(source);
DCHECK(dest);
DCHECK(encryptor_); DCHECK(encryptor_);
CHECK(encryptor_->Crypt(data, size, data)); CHECK(encryptor_->Crypt(source, source_size, dest));
} }
void EncryptionHandler::InjectVpxParserForTesting( void EncryptionHandler::InjectVpxParserForTesting(

View File

@ -47,15 +47,22 @@ class EncryptionHandler : public MediaHandler {
Status SetupProtectionPattern(StreamType stream_type); Status SetupProtectionPattern(StreamType stream_type);
bool CreateEncryptor(const EncryptionKey& encryption_key); bool CreateEncryptor(const EncryptionKey& encryption_key);
// Encrypt a VPx frame with size |source_size|. |dest| should have at least
// |source_size| bytes.
bool EncryptVpxFrame(const std::vector<VPxFrameInfo>& vpx_frames, bool EncryptVpxFrame(const std::vector<VPxFrameInfo>& vpx_frames,
uint8_t* source, const uint8_t* source,
size_t source_size, size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config); DecryptConfig* decrypt_config);
bool EncryptNalFrame(uint8_t* data, // Encrypt a NAL unit frame with size |source_size|. |dest| should have at
size_t data_length, // least |source_size| bytes.
bool EncryptNalFrame(const uint8_t* source,
size_t source_size,
uint8_t* dest,
DecryptConfig* decrypt_config); DecryptConfig* decrypt_config);
void EncryptBytes(uint8_t* data, // Encrypt an array with size |source_size|. |dest| should have at
size_t size); // least |source_size| bytes.
void EncryptBytes(const uint8_t* source, size_t source_size, uint8_t* dest);
// Testing injections. // Testing injections.
void InjectVpxParserForTesting(std::unique_ptr<VPxParser> vpx_parser); void InjectVpxParserForTesting(std::unique_ptr<VPxParser> vpx_parser);

View File

@ -178,6 +178,10 @@ const uint8_t kData[]{
// Third non-video-slice NALU for H264 or superframe index for VP9. // Third non-video-slice NALU for H264 or superframe index for VP9.
0x06, 0x67, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x06, 0x67, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e,
}; };
const size_t kDataSize = sizeof(kData);
// A short data size (less than leading clear bytes) for SampleAes audio
// testing.
const size_t kShortDataSize = 14;
// H264 subsample information for the the above data. // H264 subsample information for the the above data.
const size_t kNaluLengthSize = 1u; const size_t kNaluLengthSize = 1u;
@ -277,7 +281,7 @@ class EncryptionHandlerEncryptionTest
case kCodecVP9: case kCodecVP9:
if (vp9_subsample_encryption_) { if (vp9_subsample_encryption_) {
std::unique_ptr<MockVpxParser> mock_vpx_parser(new MockVpxParser); std::unique_ptr<MockVpxParser> mock_vpx_parser(new MockVpxParser);
EXPECT_CALL(*mock_vpx_parser, Parse(_, sizeof(kData), _)) EXPECT_CALL(*mock_vpx_parser, Parse(_, kDataSize, _))
.WillRepeatedly( .WillRepeatedly(
DoAll(SetArgPointee<2>(GetMockVpxFrameInfo()), Return(true))); DoAll(SetArgPointee<2>(GetMockVpxFrameInfo()), Return(true)));
InjectVpxParserForTesting(std::move(mock_vpx_parser)); InjectVpxParserForTesting(std::move(mock_vpx_parser));
@ -496,13 +500,8 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithNoKeyRotation) {
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
// Use single-frame segment for testing. // Use single-frame segment for testing.
ASSERT_OK(Process(StreamData::FromMediaSample( ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex, kStreamIndex, GetMediaSample(i * kSegmentDuration, kSegmentDuration,
GetMediaSample( kIsKeyFrame, kData, kDataSize))));
i * kSegmentDuration,
kSegmentDuration,
kIsKeyFrame,
kData,
sizeof(kData)))));
ASSERT_OK(Process(StreamData::FromSegmentInfo( ASSERT_OK(Process(StreamData::FromSegmentInfo(
kStreamIndex, kStreamIndex,
GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment)))); GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment))));
@ -568,13 +567,8 @@ TEST_P(EncryptionHandlerEncryptionTest, ClearLeadWithKeyRotation) {
} }
// Use single-frame segment for testing. // Use single-frame segment for testing.
ASSERT_OK(Process(StreamData::FromMediaSample( ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex, kStreamIndex, GetMediaSample(i * kSegmentDuration, kSegmentDuration,
GetMediaSample( kIsKeyFrame, kData, kDataSize))));
i * kSegmentDuration,
kSegmentDuration,
kIsKeyFrame,
kData,
sizeof(kData)))));
ASSERT_OK(Process(StreamData::FromSegmentInfo( ASSERT_OK(Process(StreamData::FromSegmentInfo(
kStreamIndex, kStreamIndex,
GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment)))); GetSegmentInfo(i * kSegmentDuration, kSegmentDuration, !kIsSubsegment))));
@ -625,20 +619,9 @@ TEST_P(EncryptionHandlerEncryptionTest, Encrypt) {
InjectCodecParser(); InjectCodecParser();
std::unique_ptr<StreamData> stream_data(new StreamData);
stream_data->stream_index = 0;
stream_data->stream_data_type = StreamDataType::kMediaSample;
stream_data->media_sample.reset(
new MediaSample(kData, sizeof(kData), nullptr, 0, kIsKeyFrame));
ASSERT_OK(Process(StreamData::FromMediaSample( ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex, kStreamIndex,
GetMediaSample( GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kDataSize))));
0,
kSampleDuration,
kIsKeyFrame,
kData,
sizeof(kData)))));
ASSERT_EQ(2u, GetOutputStreamDataVector().size()); ASSERT_EQ(2u, GetOutputStreamDataVector().size());
ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index); ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index);
ASSERT_EQ(StreamDataType::kMediaSample, ASSERT_EQ(StreamDataType::kMediaSample,
@ -654,16 +637,52 @@ TEST_P(EncryptionHandlerEncryptionTest, Encrypt) {
EXPECT_EQ(GetExpectedCryptByteBlock(), decrypt_config->crypt_byte_block()); EXPECT_EQ(GetExpectedCryptByteBlock(), decrypt_config->crypt_byte_block());
EXPECT_EQ(GetExpectedSkipByteBlock(), decrypt_config->skip_byte_block()); EXPECT_EQ(GetExpectedSkipByteBlock(), decrypt_config->skip_byte_block());
std::vector<uint8_t> expected( std::vector<uint8_t> expected(kData, kData + kDataSize);
kData, std::vector<uint8_t> actual(media_sample->data(),
kData + sizeof(kData));
std::vector<uint8_t> actual(
media_sample->data(),
media_sample->data() + media_sample->data_size()); media_sample->data() + media_sample->data_size());
ASSERT_TRUE(Decrypt(*decrypt_config, actual.data(), actual.size())); ASSERT_TRUE(Decrypt(*decrypt_config, actual.data(), actual.size()));
EXPECT_EQ(expected, actual); EXPECT_EQ(expected, actual);
} }
// Verify that the data in short audio (less than leading clear bytes) is left
// unencrypted.
TEST_P(EncryptionHandlerEncryptionTest, SampleAesEncryptShortAudio) {
if (IsVideoCodec(codec_) ||
protection_scheme_ != kAppleSampleAesProtectionScheme) {
return;
}
EncryptionParams encryption_params;
encryption_params.protection_scheme = kAppleSampleAesProtectionScheme;
SetUpEncryptionHandler(encryption_params);
const EncryptionKey mock_encryption_key = GetMockEncryptionKey();
EXPECT_CALL(mock_key_source_, GetKey(_, _))
.WillOnce(
DoAll(SetArgPointee<1>(mock_encryption_key), Return(Status::OK)));
ASSERT_OK(Process(StreamData::FromStreamInfo(
kStreamIndex, GetAudioStreamInfo(kTimeScale, codec_))));
ASSERT_OK(Process(StreamData::FromMediaSample(
kStreamIndex,
GetMediaSample(0, kSampleDuration, kIsKeyFrame, kData, kShortDataSize))));
ASSERT_EQ(2u, GetOutputStreamDataVector().size());
ASSERT_EQ(kStreamIndex, GetOutputStreamDataVector().back()->stream_index);
ASSERT_EQ(StreamDataType::kMediaSample,
GetOutputStreamDataVector().back()->stream_data_type);
auto* media_sample = GetOutputStreamDataVector().back()->media_sample.get();
auto* decrypt_config = media_sample->decrypt_config();
EXPECT_TRUE(decrypt_config->subsamples().empty());
EXPECT_EQ(kAppleSampleAesProtectionScheme,
decrypt_config->protection_scheme());
std::vector<uint8_t> expected(kData, kData + kShortDataSize);
std::vector<uint8_t> actual(media_sample->data(),
media_sample->data() + media_sample->data_size());
EXPECT_EQ(expected, actual);
}
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
CencProtectionSchemes, CencProtectionSchemes,
EncryptionHandlerEncryptionTest, EncryptionHandlerEncryptionTest,

View File

@ -708,9 +708,17 @@ bool MP4MediaParser::EnqueueSample(bool* err) {
return false; return false;
} }
const uint8_t* media_data = buf;
const size_t media_data_size = runs_->sample_size();
// Use a dummy data size of 0 to avoid copying overhead.
// Actual media data is set later.
const size_t kDummyDataSize = 0;
std::shared_ptr<MediaSample> stream_sample( std::shared_ptr<MediaSample> stream_sample(
MediaSample::CopyFrom(buf, runs_->sample_size(), runs_->is_keyframe())); MediaSample::CopyFrom(media_data, kDummyDataSize, runs_->is_keyframe()));
if (runs_->is_encrypted()) { if (runs_->is_encrypted()) {
std::shared_ptr<uint8_t> decrypted_media_data(
new uint8_t[media_data_size], std::default_delete<uint8_t[]>());
std::unique_ptr<DecryptConfig> decrypt_config = runs_->GetDecryptConfig(); std::unique_ptr<DecryptConfig> decrypt_config = runs_->GetDecryptConfig();
if (!decrypt_config) { if (!decrypt_config) {
*err = true; *err = true;
@ -719,17 +727,24 @@ bool MP4MediaParser::EnqueueSample(bool* err) {
} }
if (!decryptor_source_) { if (!decryptor_source_) {
stream_sample->SetData(media_data, media_data_size);
// If the demuxer does not have the decryptor_source_, store // If the demuxer does not have the decryptor_source_, store
// decrypt_config so that the demuxed sample can be decrypted later. // decrypt_config so that the demuxed sample can be decrypted later.
stream_sample->set_decrypt_config(std::move(decrypt_config)); stream_sample->set_decrypt_config(std::move(decrypt_config));
stream_sample->set_is_encrypted(true); stream_sample->set_is_encrypted(true);
} else if (!decryptor_source_->DecryptSampleBuffer( } else {
decrypt_config.get(), stream_sample->writable_data(), if (!decryptor_source_->DecryptSampleBuffer(decrypt_config.get(),
stream_sample->data_size())) { media_data, media_data_size,
decrypted_media_data.get())) {
*err = true; *err = true;
LOG(ERROR) << "Cannot decrypt samples."; LOG(ERROR) << "Cannot decrypt samples.";
return false; return false;
} }
stream_sample->TransferData(std::move(decrypted_media_data),
media_data_size);
}
} else {
stream_sample->SetData(media_data, media_data_size);
} }
stream_sample->set_dts(runs_->dts()); stream_sample->set_dts(runs_->dts());

View File

@ -13,6 +13,57 @@
namespace shaka { namespace shaka {
namespace media { namespace media {
namespace webm { namespace webm {
namespace {
void WriteEncryptedFrameHeader(const DecryptConfig* decrypt_config,
BufferWriter* header_buffer) {
if (decrypt_config) {
const size_t iv_size = decrypt_config->iv().size();
DCHECK_EQ(iv_size, kWebMIvSize);
if (!decrypt_config->subsamples().empty()) {
const auto& subsamples = decrypt_config->subsamples();
// Use partitioned subsample encryption: | signal_byte(3) | iv
// | num_partitions | partition_offset * n | enc_data |
DCHECK_LT(subsamples.size(), kWebMMaxSubsamples);
const size_t num_partitions =
2 * subsamples.size() - 1 -
(subsamples.back().cipher_bytes == 0 ? 1 : 0);
const size_t header_size = kWebMSignalByteSize + iv_size +
kWebMNumPartitionsSize +
(kWebMPartitionOffsetSize * num_partitions);
const uint8_t signal_byte = kWebMEncryptedSignal | kWebMPartitionedSignal;
header_buffer->AppendInt(signal_byte);
header_buffer->AppendVector(decrypt_config->iv());
header_buffer->AppendInt(static_cast<uint8_t>(num_partitions));
uint32_t partition_offset = 0;
for (size_t i = 0; i < subsamples.size() - 1; ++i) {
partition_offset += subsamples[i].clear_bytes;
header_buffer->AppendInt(partition_offset);
partition_offset += subsamples[i].cipher_bytes;
header_buffer->AppendInt(partition_offset);
}
// Add another partition between the clear bytes and cipher bytes if
// cipher bytes is not zero.
if (subsamples.back().cipher_bytes != 0) {
partition_offset += subsamples.back().clear_bytes;
header_buffer->AppendInt(partition_offset);
}
DCHECK_EQ(header_size, header_buffer->Size());
} else {
// Use whole-frame encryption: | signal_byte(1) | iv | enc_data |
const uint8_t signal_byte = kWebMEncryptedSignal;
header_buffer->AppendInt(signal_byte);
header_buffer->AppendVector(decrypt_config->iv());
}
} else {
// Clear sample: | signal_byte(0) | data |
const uint8_t signal_byte = 0x00;
header_buffer->AppendInt(signal_byte);
}
}
} // namespace
Status UpdateTrackForEncryption(const std::vector<uint8_t>& key_id, Status UpdateTrackForEncryption(const std::vector<uint8_t>& key_id,
mkvmuxer::Track* track) { mkvmuxer::Track* track) {
@ -46,69 +97,17 @@ Status UpdateTrackForEncryption(const std::vector<uint8_t>& key_id,
} }
void UpdateFrameForEncryption(MediaSample* sample) { void UpdateFrameForEncryption(MediaSample* sample) {
const size_t sample_size = sample->data_size(); DCHECK(sample);
if (sample->decrypt_config()) { BufferWriter header_buffer;
auto* decrypt_config = sample->decrypt_config(); WriteEncryptedFrameHeader(sample->decrypt_config(), &header_buffer);
const size_t iv_size = decrypt_config->iv().size();
DCHECK_EQ(iv_size, kWebMIvSize);
if (!decrypt_config->subsamples().empty()) {
auto& subsamples = decrypt_config->subsamples();
// Use partitioned subsample encryption: | signal_byte(3) | iv
// | num_partitions | partition_offset * n | enc_data |
DCHECK_LT(subsamples.size(), kWebMMaxSubsamples);
const size_t num_partitions =
2 * subsamples.size() - 1 -
(subsamples.back().cipher_bytes == 0 ? 1 : 0);
const size_t header_size = kWebMSignalByteSize + iv_size +
kWebMNumPartitionsSize +
(kWebMPartitionOffsetSize * num_partitions);
sample->resize_data(header_size + sample_size);
uint8_t* sample_data = sample->writable_data();
memmove(sample_data + header_size, sample_data, sample_size);
sample_data[0] = kWebMEncryptedSignal | kWebMPartitionedSignal;
memcpy(sample_data + kWebMSignalByteSize, decrypt_config->iv().data(),
iv_size);
sample_data[kWebMSignalByteSize + kWebMIvSize] =
static_cast<uint8_t>(num_partitions);
BufferWriter offsets_buffer; const size_t sample_size = header_buffer.Size() + sample->data_size();
uint32_t partition_offset = 0; std::shared_ptr<uint8_t> new_sample_data(new uint8_t[sample_size],
for (size_t i = 0; i < subsamples.size() - 1; ++i) { std::default_delete<uint8_t[]>());
partition_offset += subsamples[i].clear_bytes; memcpy(new_sample_data.get(), header_buffer.Buffer(), header_buffer.Size());
offsets_buffer.AppendInt(partition_offset); memcpy(&new_sample_data.get()[header_buffer.Size()], sample->data(),
partition_offset += subsamples[i].cipher_bytes; sample->data_size());
offsets_buffer.AppendInt(partition_offset); sample->TransferData(std::move(new_sample_data), sample_size);
}
// Add another partition between the clear bytes and cipher bytes if
// cipher bytes is not zero.
if (subsamples.back().cipher_bytes != 0) {
partition_offset += subsamples.back().clear_bytes;
offsets_buffer.AppendInt(partition_offset);
}
DCHECK_EQ(num_partitions * kWebMPartitionOffsetSize,
offsets_buffer.Size());
memcpy(sample_data + kWebMSignalByteSize + kWebMIvSize +
kWebMNumPartitionsSize,
offsets_buffer.Buffer(), offsets_buffer.Size());
} else {
// Use whole-frame encryption: | signal_byte(1) | iv | enc_data |
sample->resize_data(sample_size + iv_size + kWebMSignalByteSize);
uint8_t* sample_data = sample->writable_data();
// First move the sample data to after the IV; then write the IV and
// signal byte.
memmove(sample_data + iv_size + kWebMSignalByteSize, sample_data,
sample_size);
sample_data[0] = kWebMEncryptedSignal;
memcpy(sample_data + 1, decrypt_config->iv().data(), iv_size);
}
} else {
// Clear sample: | signal_byte(0) | data |
sample->resize_data(sample_size + 1);
uint8_t* sample_data = sample->writable_data();
memmove(sample_data + 1, sample_data, sample_size);
sample_data[0] = 0x00;
}
} }
} // namespace webm } // namespace webm

View File

@ -376,21 +376,34 @@ bool WebMClusterParser::OnBlock(bool is_simple_block,
return false; return false;
} }
buffer = MediaSample::CopyFrom(data + data_offset, size - data_offset, const uint8_t* media_data = data + data_offset;
additional, additional_size, is_key_frame); const size_t media_data_size = size - data_offset;
// Use a dummy data size of 0 to avoid copying overhead.
// Actual media data is set later.
const size_t kDummyDataSize = 0;
buffer = MediaSample::CopyFrom(media_data, kDummyDataSize, additional,
additional_size, is_key_frame);
if (decrypt_config) { if (decrypt_config) {
if (!decryptor_source_) { if (!decryptor_source_) {
buffer->SetData(media_data, media_data_size);
// If the demuxer does not have the decryptor_source_, store // If the demuxer does not have the decryptor_source_, store
// decrypt_config so that the demuxed sample can be decrypted later. // decrypt_config so that the demuxed sample can be decrypted later.
buffer->set_decrypt_config(std::move(decrypt_config)); buffer->set_decrypt_config(std::move(decrypt_config));
buffer->set_is_encrypted(true); buffer->set_is_encrypted(true);
} else if (!decryptor_source_->DecryptSampleBuffer( } else {
decrypt_config.get(), buffer->writable_data(), std::shared_ptr<uint8_t> decrypted_media_data(
buffer->data_size())) { new uint8_t[media_data_size], std::default_delete<uint8_t[]>());
if (!decryptor_source_->DecryptSampleBuffer(
decrypt_config.get(), media_data, media_data_size,
decrypted_media_data.get())) {
LOG(ERROR) << "Cannot decrypt samples"; LOG(ERROR) << "Cannot decrypt samples";
return false; return false;
} }
buffer->TransferData(std::move(decrypted_media_data), media_data_size);
}
} else {
buffer->SetData(media_data, media_data_size);
} }
} else { } else {
std::string id, settings, content; std::string id, settings, content;

View File

@ -816,7 +816,7 @@ void WvmMediaParser::StartMediaSampleDemux() {
bool WvmMediaParser::Output(bool output_encrypted_sample) { bool WvmMediaParser::Output(bool output_encrypted_sample) {
if (output_encrypted_sample) { if (output_encrypted_sample) {
media_sample_->set_data(sample_data_.data(), sample_data_.size()); media_sample_->SetData(sample_data_.data(), sample_data_.size());
media_sample_->set_is_encrypted(true); media_sample_->set_is_encrypted(true);
} else { } else {
if ((prev_pes_stream_id_ & kPesStreamIdVideoMask) == kPesStreamIdVideo) { if ((prev_pes_stream_id_ & kPesStreamIdVideoMask) == kPesStreamIdVideo) {
@ -827,7 +827,7 @@ bool WvmMediaParser::Output(bool output_encrypted_sample) {
LOG(ERROR) << "Could not convert h.264 byte stream sample"; LOG(ERROR) << "Could not convert h.264 byte stream sample";
return false; return false;
} }
media_sample_->set_data(nal_unit_stream.data(), nal_unit_stream.size()); media_sample_->SetData(nal_unit_stream.data(), nal_unit_stream.size());
if (!is_initialized_) { if (!is_initialized_) {
// Set extra data for video stream from AVC Decoder Config Record. // Set extra data for video stream from AVC Decoder Config Record.
// Also, set codec string from the AVC Decoder Config Record. // Also, set codec string from the AVC Decoder Config Record.
@ -914,8 +914,7 @@ bool WvmMediaParser::Output(bool output_encrypted_sample) {
} }
size_t header_size = adts_header.GetAdtsHeaderSize(frame_ptr, size_t header_size = adts_header.GetAdtsHeaderSize(frame_ptr,
frame_size); frame_size);
media_sample_->set_data(frame_ptr + header_size, media_sample_->SetData(frame_ptr + header_size, frame_size - header_size);
frame_size - header_size);
if (!is_initialized_) { if (!is_initialized_) {
for (uint32_t i = 0; i < stream_infos_.size(); i++) { for (uint32_t i = 0; i < stream_infos_.size(); i++) {
if (stream_infos_[i]->stream_type() == kStreamAudio && if (stream_infos_[i]->stream_type() == kStreamAudio &&