tesseract  5.0.0-alpha-619-ge9db
networkscratch.h
Go to the documentation of this file.
1 // File: networkscratch.h
3 // Description: Scratch space for Network layers that hides distinction
4 // between float/int implementations.
5 // Author: Ray Smith
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_NETWORKSCRATCH_H_
20 #define TESSERACT_LSTM_NETWORKSCRATCH_H_
21 
22 #include <mutex>
24 #include "matrix.h"
25 #include "networkio.h"
26 
27 namespace tesseract {
28 
29 // Generic scratch space for network layers. Provides NetworkIO that can store
30 // a complete set (over time) of intermediates, and GenericVector<float>
31 // scratch space that auto-frees after use. The aim here is to provide a set
32 // of temporary buffers to network layers that can be reused between layers
33 // and don't have to be reallocated on each call.
35  public:
36  NetworkScratch() : int_mode_(false) {}
37  ~NetworkScratch() = default;
38 
39  // Sets the network representation. If the representation is integer, then
40  // default (integer) NetworkIOs are separated from the always-float variety.
41  // This saves memory by having separate int-specific and float-specific
42  // stacks. If the network representation is float, then all NetworkIOs go
43  // to the float stack.
44  void set_int_mode(bool int_mode) {
45  int_mode_ = int_mode;
46  }
47 
48  // Class that acts like a NetworkIO (by having an implicit cast operator),
49  // yet actually holds a pointer to NetworkIOs in the source NetworkScratch,
50  // and knows how to unstack the borrowed pointers on destruction.
51  class IO {
52  public:
53  // The NetworkIO should be sized after construction.
54  IO(const NetworkIO& src, NetworkScratch* scratch)
55  : int_mode_(scratch->int_mode_ && src.int_mode()),
56  scratch_space_(scratch) {
57  network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
58  : scratch_space_->float_stack_.Borrow();
59  }
60  // Default constructor for arrays. Use one of the Resize functions
61  // below to initialize and size.
62  IO() : int_mode_(false), network_io_(nullptr), scratch_space_(nullptr) {}
63 
64  ~IO() {
65  if (scratch_space_ == nullptr) {
66  ASSERT_HOST(network_io_ == nullptr);
67  } else if (int_mode_) {
68  scratch_space_->int_stack_.Return(network_io_);
69  } else {
70  scratch_space_->float_stack_.Return(network_io_);
71  }
72  }
73  // Resizes the array (and stride), avoiding realloc if possible, to the
74  // size from various size specs:
75  // Same time size, given number of features.
76  void Resize(const NetworkIO& src, int num_features,
77  NetworkScratch* scratch) {
78  if (scratch_space_ == nullptr) {
79  int_mode_ = scratch->int_mode_ && src.int_mode();
80  scratch_space_ = scratch;
81  network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
82  : scratch_space_->float_stack_.Borrow();
83  }
84  network_io_->Resize(src, num_features);
85  }
86  // Resizes to a specific size as a temp buffer. No batches, no y-dim.
87  void Resize2d(bool int_mode, int width, int num_features,
88  NetworkScratch* scratch) {
89  if (scratch_space_ == nullptr) {
90  int_mode_ = scratch->int_mode_ && int_mode;
91  scratch_space_ = scratch;
92  network_io_ = int_mode_ ? scratch_space_->int_stack_.Borrow()
93  : scratch_space_->float_stack_.Borrow();
94  }
95  network_io_->Resize2d(int_mode, width, num_features);
96  }
97  // Resize forcing a float representation with the width of src and the given
98  // number of features.
99  void ResizeFloat(const NetworkIO& src, int num_features,
100  NetworkScratch* scratch) {
101  if (scratch_space_ == nullptr) {
102  int_mode_ = false;
103  scratch_space_ = scratch;
104  network_io_ = scratch_space_->float_stack_.Borrow();
105  }
106  network_io_->ResizeFloat(src, num_features);
107  }
108 
109  // Returns a ref to a NetworkIO that enables *this to be treated as if
110  // it were just a NetworkIO*.
112  return *network_io_;
113  }
115  return network_io_;
116  }
117  operator NetworkIO*() {
118  return network_io_;
119  }
120 
121  private:
122  // True if this is from the always-float stack, otherwise the default stack.
123  bool int_mode_;
124  // The NetworkIO that we have borrowed from the scratch_space_.
125  NetworkIO* network_io_;
126  // The source scratch_space_. Borrowed pointer, used to free the
127  // NetworkIO. Don't delete!
128  NetworkScratch* scratch_space_;
129  }; // class IO.
130 
131  // Class that acts like a fixed array of float, yet actually uses space
132  // from a GenericVector<float> in the source NetworkScratch, and knows how
133  // to unstack the borrowed vector on destruction.
134  class FloatVec {
135  public:
136  // The array will have size elements in it, uninitialized.
137  FloatVec(int size, NetworkScratch* scratch)
138  : vec_(nullptr), scratch_space_(scratch) {
139  Init(size, scratch);
140  }
141  // Default constructor is for arrays. Use Init to setup.
142  FloatVec() : vec_(nullptr), data_(nullptr), scratch_space_(nullptr) {}
144  if (scratch_space_ != nullptr) scratch_space_->vec_stack_.Return(vec_);
145  }
146 
147  void Init(int size, NetworkScratch* scratch) {
148  if (scratch_space_ != nullptr && vec_ != nullptr)
149  scratch_space_->vec_stack_.Return(vec_);
150  scratch_space_ = scratch;
151  vec_ = scratch_space_->vec_stack_.Borrow();
152  vec_->resize_no_init(size);
153  data_ = &(*vec_)[0];
154  }
155 
156  // Use the cast operator instead of operator[] so the FloatVec can be used
157  // as a double* argument to a function call.
158  operator double*() const { return data_; }
159  double* get() { return data_; }
160 
161  private:
162  // Vector borrowed from the scratch space. Use Return to free it.
163  GenericVector<double>* vec_;
164  // Short-cut pointer to the underlying array.
165  double* data_;
166  // The source scratch_space_. Borrowed pointer, used to free the
167  // vector. Don't delete!
168  NetworkScratch* scratch_space_;
169  }; // class FloatVec
170 
171  // Class that acts like a 2-D array of double, yet actually uses space
172  // from the source NetworkScratch, and knows how to unstack the borrowed
173  // array on destruction.
175  public:
176  // Default constructor is for arrays. Use Init to setup.
177  GradientStore() : array_(nullptr), scratch_space_(nullptr) {}
179  if (scratch_space_ != nullptr) scratch_space_->array_stack_.Return(array_);
180  }
181 
182  void Init(int size1, int size2, NetworkScratch* scratch) {
183  if (scratch_space_ != nullptr && array_ != nullptr)
184  scratch_space_->array_stack_.Return(array_);
185  scratch_space_ = scratch;
186  array_ = scratch_space_->array_stack_.Borrow();
187  array_->Resize(size1, size2, 0.0);
188  }
189 
190  // Accessors to get to the underlying TransposedArray.
191  TransposedArray* get() const { return array_; }
192  const TransposedArray& operator*() const { return *array_; }
193 
194  private:
195  // Array borrowed from the scratch space. Use Return to free it.
196  TransposedArray* array_;
197  // The source scratch_space_. Borrowed pointer, used to free the
198  // vector. Don't delete!
199  NetworkScratch* scratch_space_;
200  }; // class GradientStore
201 
202  // Class that does the work of holding a stack of objects, a stack pointer
203  // and a vector of in-use flags, so objects can be returned out of order.
204  // It is safe to attempt to Borrow/Return in multiple threads.
205  template<typename T> class Stack {
206  public:
207  Stack() : stack_top_(0) {
208  }
209 
210  // Lends out the next free item, creating one if none available, sets
211  // the used flags and increments the stack top.
212  T* Borrow() {
213  std::lock_guard<std::mutex> lock(mutex_);
214  if (stack_top_ == stack_.size()) {
215  stack_.push_back(new T);
216  flags_.push_back(false);
217  }
218  flags_[stack_top_] = true;
219  return stack_[stack_top_++];
220  }
221  // Takes back the given item, and marks it free. Item does not have to be
222  // the most recently lent out, but free slots don't get re-used until the
223  // blocking item is returned. The assumption is that there will only be
224  // small, temporary variations from true stack use. (Determined by the order
225  // of destructors within a local scope.)
226  void Return(T* item) {
227  std::lock_guard<std::mutex> lock(mutex_);
228  // Linear search will do.
229  int index = stack_top_ - 1;
230  while (index >= 0 && stack_[index] != item) --index;
231  if (index >= 0) flags_[index] = false;
232  while (stack_top_ > 0 && !flags_[stack_top_ - 1]) --stack_top_;
233  }
234 
235  private:
236  PointerVector<T> stack_;
237  GenericVector<bool> flags_;
238  int stack_top_;
239  std::mutex mutex_;
240  }; // class Stack.
241 
242  private:
243  // If true, the network weights are int8_t, if false, float.
244  bool int_mode_;
245  // Stacks of NetworkIO and GenericVector<float>. Once allocated, they are not
246  // deleted until the NetworkScratch is deleted.
247  Stack<NetworkIO> int_stack_;
248  Stack<NetworkIO> float_stack_;
249  Stack<GenericVector<double> > vec_stack_;
250  Stack<TransposedArray> array_stack_;
251 };
252 
253 } // namespace tesseract.
254 
255 #endif // TESSERACT_LSTM_NETWORKSCRATCH_H_
tesseract::NetworkScratch::GradientStore::~GradientStore
~GradientStore()
Definition: networkscratch.h:178
tesseract::NetworkIO::int_mode
bool int_mode() const
Definition: networkio.h:127
tesseract::NetworkScratch::Stack::Borrow
T * Borrow()
Definition: networkscratch.h:212
networkio.h
tesseract::NetworkScratch::~NetworkScratch
~NetworkScratch()=default
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::NetworkScratch::FloatVec
Definition: networkscratch.h:134
tesseract::NetworkScratch::Stack::Return
void Return(T *item)
Definition: networkscratch.h:226
tesseract::NetworkScratch::set_int_mode
void set_int_mode(bool int_mode)
Definition: networkscratch.h:44
tesseract::PointerVector
Definition: genericvector.h:417
tesseract::NetworkScratch
Definition: networkscratch.h:34
tesseract::NetworkScratch::IO::operator->
NetworkIO * operator->()
Definition: networkscratch.h:114
tesseract::NetworkScratch::IO::ResizeFloat
void ResizeFloat(const NetworkIO &src, int num_features, NetworkScratch *scratch)
Definition: networkscratch.h:99
tesseract::NetworkIO::ResizeFloat
void ResizeFloat(const NetworkIO &src, int num_features)
Definition: networkio.h:52
tesseract::NetworkScratch::GradientStore::Init
void Init(int size1, int size2, NetworkScratch *scratch)
Definition: networkscratch.h:182
tesseract::NetworkScratch::IO::IO
IO()
Definition: networkscratch.h:62
tesseract::NetworkScratch::GradientStore::get
TransposedArray * get() const
Definition: networkscratch.h:191
tesseract::NetworkScratch::IO::IO
IO(const NetworkIO &src, NetworkScratch *scratch)
Definition: networkscratch.h:54
genericvector.h
tesseract::NetworkIO::Resize2d
void Resize2d(bool int_mode, int width, int num_features)
Definition: networkio.cpp:35
tesseract::NetworkScratch::GradientStore::GradientStore
GradientStore()
Definition: networkscratch.h:177
tesseract::NetworkScratch::GradientStore::operator*
const TransposedArray & operator*() const
Definition: networkscratch.h:192
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::NetworkScratch::IO::Resize2d
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
Definition: networkscratch.h:87
tesseract::NetworkScratch::Stack
Definition: networkscratch.h:205
tesseract::NetworkScratch::FloatVec::get
double * get()
Definition: networkscratch.h:159
matrix.h
GenericVector::resize_no_init
void resize_no_init(int size)
Definition: genericvector.h:65
tesseract::NetworkScratch::FloatVec::FloatVec
FloatVec(int size, NetworkScratch *scratch)
Definition: networkscratch.h:137
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::NetworkScratch::FloatVec::Init
void Init(int size, NetworkScratch *scratch)
Definition: networkscratch.h:147
tesseract::NetworkScratch::GradientStore
Definition: networkscratch.h:174
tesseract
Definition: baseapi.h:65
tesseract::NetworkScratch::IO::operator*
NetworkIO & operator*()
Definition: networkscratch.h:111
GenericVector< double >
tesseract::NetworkScratch::Stack::Stack
Stack()
Definition: networkscratch.h:207
tesseract::NetworkScratch::IO
Definition: networkscratch.h:51
tesseract::NetworkScratch::FloatVec::FloatVec
FloatVec()
Definition: networkscratch.h:142
tesseract::NetworkIO::Resize
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
tesseract::TransposedArray
Definition: weightmatrix.h:32
tesseract::NetworkScratch::FloatVec::~FloatVec
~FloatVec()
Definition: networkscratch.h:143
GENERIC_2D_ARRAY::Resize
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:104
tesseract::NetworkScratch::IO::Resize
void Resize(const NetworkIO &src, int num_features, NetworkScratch *scratch)
Definition: networkscratch.h:76
tesseract::NetworkScratch::IO::~IO
~IO()
Definition: networkscratch.h:64
tesseract::NetworkScratch::NetworkScratch
NetworkScratch()
Definition: networkscratch.h:36