Shaka Packager SDK
threaded_io_file.cc
1 // Copyright 2015 Google Inc. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file or at
5 // https://developers.google.com/open-source/licenses/bsd
6 
7 #include "packager/file/threaded_io_file.h"
8 
9 #include "packager/base/bind.h"
10 #include "packager/base/bind_helpers.h"
11 #include "packager/base/location.h"
12 #include "packager/base/threading/worker_pool.h"
13 
14 namespace shaka {
15 
16 ThreadedIoFile::ThreadedIoFile(std::unique_ptr<File, FileCloser> internal_file,
17  Mode mode,
18  uint64_t io_cache_size,
19  uint64_t io_block_size)
20  : File(internal_file->file_name()),
21  internal_file_(std::move(internal_file)),
22  mode_(mode),
23  cache_(io_cache_size),
24  io_buffer_(io_block_size),
25  position_(0),
26  size_(0),
27  eof_(false),
28  flushing_(false),
29  flush_complete_event_(base::WaitableEvent::ResetPolicy::AUTOMATIC,
30  base::WaitableEvent::InitialState::NOT_SIGNALED),
31  internal_file_error_(0),
32  task_exit_event_(base::WaitableEvent::ResetPolicy::AUTOMATIC,
33  base::WaitableEvent::InitialState::NOT_SIGNALED) {
34  DCHECK(internal_file_);
35 }
36 
37 ThreadedIoFile::~ThreadedIoFile() {}
38 
40  DCHECK(internal_file_);
41 
42  if (!internal_file_->Open())
43  return false;
44 
45  position_ = 0;
46  size_ = internal_file_->Size();
47 
48  base::WorkerPool::PostTask(
49  FROM_HERE,
50  base::Bind(&ThreadedIoFile::TaskHandler, base::Unretained(this)),
51  true /* task_is_slow */);
52  return true;
53 }
54 
56  DCHECK(internal_file_);
57 
58  bool result = true;
59  if (mode_ == kOutputMode)
60  result = Flush();
61 
62  cache_.Close();
63  task_exit_event_.Wait();
64 
65  result &= internal_file_.release()->Close();
66  delete this;
67  return result;
68 }
69 
70 int64_t ThreadedIoFile::Read(void* buffer, uint64_t length) {
71  DCHECK(internal_file_);
72  DCHECK_EQ(kInputMode, mode_);
73 
74  if (eof_.load(std::memory_order_relaxed) && !cache_.BytesCached())
75  return 0;
76 
77  if (internal_file_error_.load(std::memory_order_relaxed))
78  return internal_file_error_.load(std::memory_order_relaxed);
79 
80  uint64_t bytes_read = cache_.Read(buffer, length);
81  position_ += bytes_read;
82 
83  return bytes_read;
84 }
85 
86 int64_t ThreadedIoFile::Write(const void* buffer, uint64_t length) {
87  DCHECK(internal_file_);
88  DCHECK_EQ(kOutputMode, mode_);
89 
90  if (internal_file_error_.load(std::memory_order_relaxed))
91  return internal_file_error_.load(std::memory_order_relaxed);
92 
93  uint64_t bytes_written = cache_.Write(buffer, length);
94  position_ += bytes_written;
95  if (position_ > size_)
96  size_ = position_;
97 
98  return bytes_written;
99 }
100 
102  DCHECK(internal_file_);
103 
104  return size_;
105 }
106 
108  DCHECK(internal_file_);
109  DCHECK_EQ(kOutputMode, mode_);
110 
111  if (internal_file_error_.load(std::memory_order_relaxed))
112  return false;
113 
114  flushing_ = true;
115  cache_.Close();
116  flush_complete_event_.Wait();
117  return internal_file_->Flush();
118 }
119 
120 bool ThreadedIoFile::Seek(uint64_t position) {
121  if (mode_ == kOutputMode) {
122  // Writing. Just flush the cache and seek.
123  if (!Flush())
124  return false;
125  if (!internal_file_->Seek(position))
126  return false;
127  } else {
128  // Reading. Close cache, wait for thread task to exit, seek, and re-post
129  // the task.
130  cache_.Close();
131  task_exit_event_.Wait();
132  bool result = internal_file_->Seek(position);
133  if (!result) {
134  // Seek failed. Seek to logical position instead.
135  if (!internal_file_->Seek(position_) && (position != position_)) {
136  LOG(WARNING) << "Seek failed. ThreadedIoFile left in invalid state.";
137  }
138  }
139  cache_.Reopen();
140  eof_ = false;
141  base::WorkerPool::PostTask(
142  FROM_HERE,
143  base::Bind(&ThreadedIoFile::TaskHandler, base::Unretained(this)),
144  true /* task_is_slow */);
145  if (!result)
146  return false;
147  }
148  position_ = position;
149  return true;
150 }
151 
152 bool ThreadedIoFile::Tell(uint64_t* position) {
153  DCHECK(position);
154 
155  *position = position_;
156  return true;
157 }
158 
159 void ThreadedIoFile::TaskHandler() {
160  if (mode_ == kInputMode)
161  RunInInputMode();
162  else
163  RunInOutputMode();
164  task_exit_event_.Signal();
165 }
166 
167 void ThreadedIoFile::RunInInputMode() {
168  DCHECK(internal_file_);
169  DCHECK_EQ(kInputMode, mode_);
170 
171  while (true) {
172  int64_t read_result =
173  internal_file_->Read(&io_buffer_[0], io_buffer_.size());
174  if (read_result <= 0) {
175  eof_.store(read_result == 0, std::memory_order_relaxed);
176  internal_file_error_.store(read_result, std::memory_order_relaxed);
177  cache_.Close();
178  return;
179  }
180  if (cache_.Write(&io_buffer_[0], read_result) == 0) {
181  return;
182  }
183  }
184 }
185 
186 void ThreadedIoFile::RunInOutputMode() {
187  DCHECK(internal_file_);
188  DCHECK_EQ(kOutputMode, mode_);
189 
190  while (true) {
191  uint64_t write_bytes = cache_.Read(&io_buffer_[0], io_buffer_.size());
192  if (write_bytes == 0) {
193  if (flushing_) {
194  cache_.Reopen();
195  flushing_ = false;
196  flush_complete_event_.Signal();
197  } else {
198  return;
199  }
200  } else {
201  uint64_t bytes_written(0);
202  while (bytes_written < write_bytes) {
203  int64_t write_result = internal_file_->Write(
204  &io_buffer_[bytes_written], write_bytes - bytes_written);
205  if (write_result < 0) {
206  internal_file_error_.store(write_result, std::memory_order_relaxed);
207  cache_.Close();
208  if (flushing_) {
209  flushing_ = false;
210  flush_complete_event_.Signal();
211  }
212  return;
213  }
214  bytes_written += write_result;
215  }
216  }
217  }
218 }
219 
220 } // namespace shaka
int64_t Size() override
int64_t Write(const void *buffer, uint64_t length) override
STL namespace.
All the methods that are virtual are virtual for mocking.
bool Open() override
Internal open. Should not be used directly.
bool Seek(uint64_t position) override
bool Tell(uint64_t *position) override
int64_t Read(void *buffer, uint64_t length) override