diff --git a/packager/media/base/media_handler_test_base.cc b/packager/media/base/media_handler_test_base.cc index c54c8246c8..2a6ce3df65 100644 --- a/packager/media/base/media_handler_test_base.cc +++ b/packager/media/base/media_handler_test_base.cc @@ -235,7 +235,7 @@ Status MediaHandlerTestBase::SetUpAndInitializeGraph( // Add and connect all the requested outputs. for (size_t i = 0; i < output_count; i++) { - outputs_.emplace_back(new MockOutputMediaHandler); + outputs_.emplace_back(new testing::NiceMock); } for (auto& output : outputs_) { diff --git a/packager/media/base/media_handler_test_base.h b/packager/media/base/media_handler_test_base.h index 098bfcae97..eab512ef1a 100644 --- a/packager/media/base/media_handler_test_base.h +++ b/packager/media/base/media_handler_test_base.h @@ -13,6 +13,11 @@ namespace shaka { namespace media { +MATCHER_P(IsStreamInfo, stream_index, "") { + return arg->stream_index == stream_index && + arg->stream_data_type == StreamDataType::kStreamInfo; +} + MATCHER_P3(IsStreamInfo, stream_index, time_scale, encrypted, "") { *result_listener << "which is (" << stream_index << "," << time_scale << "," << (encrypted ? "encrypted" : "not encrypted") << ")"; diff --git a/packager/media/formats/webvtt/webvtt.gyp b/packager/media/formats/webvtt/webvtt.gyp index 337d0a7c26..4d53cc06c4 100644 --- a/packager/media/formats/webvtt/webvtt.gyp +++ b/packager/media/formats/webvtt/webvtt.gyp @@ -23,6 +23,8 @@ 'webvtt_parser.h', 'webvtt_sample_converter.cc', 'webvtt_sample_converter.h', + 'webvtt_segmenter.cc', + 'webvtt_segmenter.h', 'webvtt_timestamp.cc', 'webvtt_timestamp.h', ], @@ -41,6 +43,7 @@ 'webvtt_media_parser_unittest.cc', 'webvtt_parser_unittest.cc', 'webvtt_sample_converter_unittest.cc', + 'webvtt_segmenter_unittest.cc', 'webvtt_timestamp_unittest.cc', ], 'dependencies': [ diff --git a/packager/media/formats/webvtt/webvtt_segmenter.cc b/packager/media/formats/webvtt/webvtt_segmenter.cc new file mode 100644 index 0000000000..4612341d32 --- /dev/null +++ b/packager/media/formats/webvtt/webvtt_segmenter.cc @@ -0,0 +1,103 @@ +// Copyright 2017 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#include "packager/media/formats/webvtt/webvtt_segmenter.h" + +namespace shaka { +namespace media { +namespace { +const size_t kStreamIndex = 0; +} + +WebVttSegmenter::WebVttSegmenter(uint64_t segment_duration_ms) + : segment_duration_ms_(segment_duration_ms) {} + +Status WebVttSegmenter::InitializeInternal() { + return Status::OK; +} + +Status WebVttSegmenter::Process(std::unique_ptr stream_data) { + switch (stream_data->stream_data_type) { + case StreamDataType::kStreamInfo: + return DispatchStreamInfo(kStreamIndex, + std::move(stream_data->stream_info)); + case StreamDataType::kTextSample: + return OnTextSample(stream_data->text_sample); + default: + return Status(error::INTERNAL_ERROR, + "Invalid stream data type for this handler"); + } +} + +Status WebVttSegmenter::OnFlushRequest(size_t input_stream_index) { + Status status; + while (status.ok() && samples_.size()) { + status.Update(OnSegmentEnd()); + } + return status.ok() ? FlushAllDownstreams() : status; +} + +Status WebVttSegmenter::OnTextSample(std::shared_ptr sample) { + const uint64_t start_segment = sample->start_time() / segment_duration_ms_; + + // Find the last segment that overlaps the sample. Adjust the sample by one + // ms (smallest time unit) in case |EndTime| falls on the segment boundary. + DCHECK_GT(sample->duration(), 0u); + const uint64_t ending_segment = + (sample->EndTime() - 1) / segment_duration_ms_; + + DCHECK_GE(ending_segment, start_segment); + + // Samples must always be advancing. If a sample comes in out of order, + // skip the sample. + if (samples_.size() && samples_.top().segment > start_segment) { + LOG(WARNING) << "New sample has arrived out of order. Skipping sample " + << "as segment start is " << start_segment << " and segment " + << "head is " << samples_.top().segment << "."; + return Status::OK; + } + + for (uint64_t segment = start_segment; segment <= ending_segment; segment++) { + WebVttSegmentedTextSample seg_sample; + seg_sample.segment = segment; + seg_sample.sample = sample; + + samples_.push(seg_sample); + } + + Status status; + + while (status.ok() && samples_.size() && + samples_.top().segment < start_segment) { + // WriteNextSegment will pop elements from |samples_| which will + // eventually allow the loop to exit. + status.Update(OnSegmentEnd()); + } + + return status; +} + +Status WebVttSegmenter::OnSegmentEnd() { + DCHECK(samples_.size()); + + const uint64_t segment = samples_.top().segment; + + std::shared_ptr info = std::make_shared(); + info->start_timestamp = segment * segment_duration_ms_; + info->duration = segment_duration_ms_; + + Status status = DispatchSegmentInfo(kStreamIndex, std::move(info)); + + while (status.ok() && samples_.size() && samples_.top().segment == segment) { + status.Update( + DispatchTextSample(kStreamIndex, std::move(samples_.top().sample))); + samples_.pop(); + } + + return status; +} +} // namespace media +} // namespace shaka diff --git a/packager/media/formats/webvtt/webvtt_segmenter.h b/packager/media/formats/webvtt/webvtt_segmenter.h new file mode 100644 index 0000000000..743215493d --- /dev/null +++ b/packager/media/formats/webvtt/webvtt_segmenter.h @@ -0,0 +1,71 @@ +// Copyright 2017 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#ifndef PACKAGER_MEDIA_FORMATS_WEBVTT_WEBVTT_SEGMENTER_H_ +#define PACKAGER_MEDIA_FORMATS_WEBVTT_WEBVTT_SEGMENTER_H_ + +#include + +#include +#include + +#include "packager/media/base/media_handler.h" +#include "packager/status.h" + +namespace shaka { +namespace media { + +// Because a text sample can be in multiple segments, this struct +// allows us to associate a segment with a sample. This allows us +// to easily sort samples base on segment then time. +struct WebVttSegmentedTextSample { + uint64_t segment = 0; + std::shared_ptr sample; +}; + +class WebVttSegmentedTextSampleCompare { + public: + bool operator()(const WebVttSegmentedTextSample& left, + const WebVttSegmentedTextSample& right) const { + // If the samples are in the same segment, then the start time is the + // only way to order the two segments. + if (left.segment == right.segment) { + return left.sample->start_time() > right.sample->start_time(); + } + + // Time will not matter as the samples are not in the same segment. + return left.segment > right.segment; + } +}; + +class WebVttSegmenter : public MediaHandler { + public: + explicit WebVttSegmenter(uint64_t segment_duration_ms); + + protected: + Status Process(std::unique_ptr stream_data) override; + Status OnFlushRequest(size_t input_stream_index) override; + + private: + WebVttSegmenter(const WebVttSegmenter&) = delete; + WebVttSegmenter& operator=(const WebVttSegmenter&) = delete; + + Status InitializeInternal() override; + + Status OnTextSample(std::shared_ptr sample); + Status OnSegmentEnd(); + + uint64_t segment_duration_ms_; + std::priority_queue, + WebVttSegmentedTextSampleCompare> + samples_; +}; + +} // namespace media +} // namespace shaka + +#endif // PACKAGER_MEDIA_FORMATS_WEBVTT_WEBVTT_SEGMENTER_H_ diff --git a/packager/media/formats/webvtt/webvtt_segmenter_unittest.cc b/packager/media/formats/webvtt/webvtt_segmenter_unittest.cc new file mode 100644 index 0000000000..4c7df1beeb --- /dev/null +++ b/packager/media/formats/webvtt/webvtt_segmenter_unittest.cc @@ -0,0 +1,246 @@ +// Copyright 2017 Google Inc. All rights reserved. +// +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file or at +// https://developers.google.com/open-source/licenses/bsd + +#include +#include + +#include "packager/media/base/media_handler_test_base.h" +#include "packager/media/formats/webvtt/webvtt_segmenter.h" +#include "packager/status_test_util.h" + +namespace shaka { +namespace media { + +namespace { +const int64_t kStartTimeSigned = 0; +const uint64_t kStartTime = 0; +const int64_t kSegmentDuration = 10000; // 10 seconds + +const size_t kStreamIndex = 0; + +const size_t kInputCount = 1; +const size_t kOutputCount = 1; +const size_t kInputIndex = 0; +const size_t kOutputIndex = 0; + +const bool kEncrypted = true; +const bool kSubSegment = true; + +const char* kId[] = {"cue 1 id", "cue 2 id"}; +const char* kPayload[] = {"cue 1 payload", "cue 2 payload"}; + +const std::string kNoSettings = ""; +} // namespace + +class WebVttSegmenterTest : public MediaHandlerTestBase { + protected: + void SetUp() { + ASSERT_OK(SetUpAndInitializeGraph( + std::make_shared(kSegmentDuration), kInputCount, + kOutputCount)); + } +}; + +// When a cue ends on a segment boundry, it does not create a cue with a 0 ms +// duration +// | | +// |[---A---]| +// | | +TEST_F(WebVttSegmenterTest, CueEndingOnSegmentStart) { + const uint64_t kSampleDuration = kSegmentDuration; + + { + testing::InSequence s; + + EXPECT_CALL(*Output(kOutputIndex), OnProcess(IsStreamInfo(kStreamIndex))); + + // Segment One + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsSegmentInfo(kStreamIndex, kStartTimeSigned, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsTextSample(kId[0], kStartTime, kStartTime + kSampleDuration, + kNoSettings, kPayload[0]))); + + EXPECT_CALL(*Output(kOutputIndex), OnFlush(kStreamIndex)); + } + + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromStreamInfo(kStreamIndex, + GetTextStreamInfo()))); + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromTextSample( + kStreamIndex, + GetTextSample(kId[0], kStartTime, + kStartTime + kSampleDuration, kPayload[0])))); + ASSERT_OK(Input(kInputIndex)->FlushAllDownstreams()); +} + +// Each cue belongs in its own segment, so before each cue is passed +// downstream, a 'input of segment' message should be passed downstream. +// | +// [---A---] | +// | [---B---] +// | +TEST_F(WebVttSegmenterTest, CreatesSegmentsForCues) { + // Divide segment duration by 2 so that the sample duration won't be a full + // segment. + const uint64_t kSampleDuration = kSegmentDuration / 2; + + { + testing::InSequence s; + + EXPECT_CALL(*Output(kOutputIndex), OnProcess(IsStreamInfo(kStreamIndex))); + + // Segment One + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsSegmentInfo(kStreamIndex, kStartTimeSigned, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsTextSample(kId[0], kStartTime, kStartTime + kSampleDuration, + kNoSettings, kPayload[0]))); + + // Segment Two + EXPECT_CALL(*Output(kOutputIndex), + OnProcess(IsSegmentInfo( + kStreamIndex, kStartTimeSigned + kSegmentDuration, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsTextSample(kId[1], kStartTime + kSegmentDuration, + kStartTime + kSegmentDuration + kSampleDuration, + kNoSettings, kPayload[1]))); + + EXPECT_CALL(*Output(kOutputIndex), OnFlush(kStreamIndex)); + } + + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromStreamInfo(kStreamIndex, + GetTextStreamInfo()))); + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromTextSample( + kStreamIndex, + GetTextSample(kId[0], kStartTime, + kStartTime + kSampleDuration, kPayload[0])))); + ASSERT_OK( + Input(kInputIndex) + ->Dispatch(StreamData::FromTextSample( + kStreamIndex, + GetTextSample(kId[1], kStartTime + kSegmentDuration, + kStartTime + kSegmentDuration + kSampleDuration, + kPayload[1])))); + ASSERT_OK(Input(kInputIndex)->FlushAllDownstreams()); +} + +// [---A---] | | +// | | +// | | [---B---] +// | | +TEST_F(WebVttSegmenterTest, SkipsEmptySegments) { + const uint64_t kSampleDuration = kSegmentDuration / 2; + + { + testing::InSequence s; + + EXPECT_CALL(*Output(kOutputIndex), OnProcess(IsStreamInfo(kStreamIndex))); + + // Segment One + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsSegmentInfo(kStreamIndex, kStartTimeSigned, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsTextSample(kId[0], kStartTime, kStartTime + kSampleDuration, + kNoSettings, kPayload[0]))); + + // There is no segment two + + // Segment Three + EXPECT_CALL(*Output(kOutputIndex), + OnProcess(IsSegmentInfo( + kStreamIndex, kStartTimeSigned + 2 * kSegmentDuration, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL(*Output(kOutputIndex), + OnProcess(IsTextSample( + kId[1], kStartTime + 2 * kSegmentDuration, + kStartTime + 2 * kSegmentDuration + kSampleDuration, + kNoSettings, kPayload[1]))); + + EXPECT_CALL(*Output(kOutputIndex), OnFlush(kStreamIndex)); + } + + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromStreamInfo(kStreamIndex, + GetTextStreamInfo()))); + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromTextSample( + kStreamIndex, + GetTextSample(kId[0], kStartTime, + kStartTime + kSampleDuration, kPayload[0])))); + ASSERT_OK( + Input(kInputIndex) + ->Dispatch(StreamData::FromTextSample( + kStreamIndex, + GetTextSample(kId[1], kStartTime + 2 * kSegmentDuration, + kStartTime + 2 * kSegmentDuration + kSampleDuration, + kPayload[1])))); + ASSERT_OK(Input(kInputIndex)->FlushAllDownstreams()); +} + +// When a cue crossing the segment boundary, the cue should be included in +// both segments. +// | +// [-----A-----|---------] +// | +TEST_F(WebVttSegmenterTest, CueCrossesSegments) { + const uint64_t kSampleDuration = 2 * kSegmentDuration; + + { + testing::InSequence s; + + EXPECT_CALL(*Output(kOutputIndex), OnProcess(IsStreamInfo(kStreamIndex))); + + // Segment One + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsSegmentInfo(kStreamIndex, kStartTimeSigned, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsTextSample(kId[0], kStartTime, kStartTime + kSampleDuration, + kNoSettings, kPayload[0]))); + + // Segment Two + EXPECT_CALL(*Output(kOutputIndex), + OnProcess(IsSegmentInfo( + kStreamIndex, kStartTimeSigned + kSegmentDuration, + kSegmentDuration, !kSubSegment, !kEncrypted))); + EXPECT_CALL( + *Output(kOutputIndex), + OnProcess(IsTextSample(kId[0], kStartTime, kStartTime + kSampleDuration, + kNoSettings, kPayload[0]))); + + EXPECT_CALL(*Output(kOutputIndex), OnFlush(kStreamIndex)); + } + + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromStreamInfo(kStreamIndex, + GetTextStreamInfo()))); + ASSERT_OK(Input(kInputIndex) + ->Dispatch(StreamData::FromTextSample( + kStreamIndex, + GetTextSample(kId[0], kStartTime, + kStartTime + kSampleDuration, kPayload[0])))); + ASSERT_OK(Input(kInputIndex)->FlushAllDownstreams()); +} + +} // namespace media +} // namespace shaka