tesseract  5.0.0-alpha-619-ge9db
weightmatrix.h
Go to the documentation of this file.
1 // File: weightmatrix.h
3 // Description: Hides distinction between float/int implementations.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2014, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifndef TESSERACT_LSTM_WEIGHTMATRIX_H_
19 #define TESSERACT_LSTM_WEIGHTMATRIX_H_
20 
#include <cstdint>
#include <memory>
#include <vector>

#include "genericvector.h"
#include "intsimdmatrix.h"
#include "matrix.h"
#include "tprintf.h"
26 
27 namespace tesseract {
28 
29 // Convenience instantiation of GENERIC_2D_ARRAY<double> with additional
30 // operations to write a strided vector, so the transposed form of the input
31 // is memory-contiguous.
32 class TransposedArray : public GENERIC_2D_ARRAY<double> {
33  public:
34  // Copies the whole input transposed, converted to double, into *this.
35  void Transpose(const GENERIC_2D_ARRAY<double>& input);
36  // Writes a vector of data representing a timestep (gradients or sources).
37  // The data is assumed to be of size1 in size (the strided dimension).
38  ~TransposedArray() override;
39  void WriteStrided(int t, const float* data) {
40  int size1 = dim1();
41  for (int i = 0; i < size1; ++i) put(i, t, data[i]);
42  }
43  void WriteStrided(int t, const double* data) {
44  int size1 = dim1();
45  for (int i = 0; i < size1; ++i) put(i, t, data[i]);
46  }
47  // Prints the first and last num elements of the un-transposed array.
48  void PrintUnTransposed(int num) {
49  int num_features = dim1();
50  int width = dim2();
51  for (int y = 0; y < num_features; ++y) {
52  for (int t = 0; t < width; ++t) {
53  if (num == 0 || t < num || t + num >= width) {
54  tprintf(" %g", (*this)(y, t));
55  }
56  }
57  tprintf("\n");
58  }
59  }
60 }; // class TransposedArray
61 
62 // Generic weight matrix for network layers. Can store the matrix as either
63 // an array of floats or int8_t. Provides functions to compute the forward and
64 // backward steps with the matrix and updates to the weights.
65 class WeightMatrix {
66  public:
67  WeightMatrix() : int_mode_(false), use_adam_(false) {}
68  // Sets up the network for training. Initializes weights using weights of
69  // scale `range` picked according to the random number generator `randomizer`.
70  // Note the order is outputs, inputs, as this is the order of indices to
71  // the matrix, so the adjacent elements are multiplied by the input during
72  // a forward operation.
73  int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range,
74  TRand* randomizer);
75  // Changes the number of outputs to the size of the given code_map, copying
76  // the old weight matrix entries for each output from code_map[output] where
77  // non-negative, and uses the mean (over all outputs) of the existing weights
78  // for all outputs with negative code_map entries. Returns the new number of
79  // weights.
80  int RemapOutputs(const std::vector<int>& code_map);
81 
82  // Converts a float network to an int network. Each set of input weights that
83  // corresponds to a single output weight is converted independently:
84  // Compute the max absolute value of the weight set.
85  // Scale so the max absolute value becomes INT8_MAX.
86  // Round to integer.
87  // Store a multiplicative scale factor (as a float) that will reproduce
88  // the original value, subject to rounding errors.
89  void ConvertToInt();
90  // Returns the size rounded up to an internal factor used by the SIMD
91  // implementation for its input.
92  int RoundInputs(int size) const {
93  if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) return size;
95  }
96 
97  // Accessors.
98  bool is_int_mode() const {
99  return int_mode_;
100  }
101  int NumOutputs() const { return int_mode_ ? wi_.dim1() : wf_.dim1(); }
102  // Provides one set of weights. Only used by peep weight maxpool.
103  const double* GetWeights(int index) const { return wf_[index]; }
104  // Provides access to the deltas (dw_).
105  double GetDW(int i, int j) const { return dw_(i, j); }
106 
107  // Allocates any needed memory for running Backward, and zeroes the deltas,
108  // thus eliminating any existing momentum.
109  void InitBackward();
110 
111  // Writes to the given file. Returns false in case of error.
112  bool Serialize(bool training, TFile* fp) const;
113  // Reads from the given file. Returns false in case of error.
114  bool DeSerialize(bool training, TFile* fp);
115  // As DeSerialize, but reads an old (float) format WeightMatrix for
116  // backward compatibility.
117  bool DeSerializeOld(bool training, TFile* fp);
118 
119  // Computes matrix.vector v = Wu.
120  // u is of size W.dim2() - 1 and the output v is of size W.dim1().
121  // u is imagined to have an extra element at the end with value 1, to
122  // implement the bias, but it doesn't actually have it.
123  // Asserts that the call matches what we have.
124  void MatrixDotVector(const double* u, double* v) const;
125  void MatrixDotVector(const int8_t* u, double* v) const;
126  // MatrixDotVector for peep weights, MultiplyAccumulate adds the
127  // component-wise products of *this[0] and v to inout.
128  void MultiplyAccumulate(const double* v, double* inout);
129  // Computes vector.matrix v = uW.
130  // u is of size W.dim1() and the output v is of size W.dim2() - 1.
131  // The last result is discarded, as v is assumed to have an imaginary
132  // last value of 1, as with MatrixDotVector.
133  void VectorDotMatrix(const double* u, double* v) const;
134  // Fills dw_[i][j] with the dot product u[i][] . v[j][], using elements
135  // from u and v, starting with u[i][offset] and v[j][offset].
136  // Note that (matching MatrixDotVector) v[last][] is missing, presumed 1.0.
137  // Runs parallel if requested. Note that inputs must be transposed.
138  void SumOuterTransposed(const TransposedArray& u, const TransposedArray& v,
139  bool parallel);
140  // Updates the weights using the given learning rate, momentum and adam_beta.
141  // num_samples is used in the Adam correction factor.
142  void Update(double learning_rate, double momentum, double adam_beta,
143  int num_samples);
144  // Adds the dw_ in other to the dw_ is *this.
145  void AddDeltas(const WeightMatrix& other);
146  // Sums the products of weight updates in *this and other, splitting into
147  // positive (same direction) in *same and negative (different direction) in
148  // *changed.
149  void CountAlternators(const WeightMatrix& other, double* same,
150  double* changed) const;
151 
152  void Debug2D(const char* msg);
153 
154  // Utility function converts an array of float to the corresponding array
155  // of double.
156  static void FloatToDouble(const GENERIC_2D_ARRAY<float>& wf,
158 
159  private:
160  // Choice between float and 8 bit int implementations.
163  // Transposed copy of wf_, used only for Backward, and set with each Update.
164  TransposedArray wf_t_;
165  // Which of wf_ and wi_ are we actually using.
166  bool int_mode_;
167  // True if we are running adam in this weight matrix.
168  bool use_adam_;
169  // If we are using wi_, then scales_ is a factor to restore the row product
170  // with a vector to the correct range.
171  GenericVector<double> scales_;
172  // Weight deltas. dw_ is the new delta, and updates_ the momentum-decaying
173  // amount to be added to wf_/wi_.
175  GENERIC_2D_ARRAY<double> updates_;
176  // Iff use_adam_, the sum of squares of dw_. The number of samples is
177  // given to Update(). Serialized iff use_adam_.
178  GENERIC_2D_ARRAY<double> dw_sq_sum_;
179  // The weights matrix reorganized in whatever way suits this instance.
180  std::vector<int8_t> shaped_w_;
181 };
182 
183 } // namespace tesseract.
184 
185 #endif // TESSERACT_LSTM_WEIGHTMATRIX_H_
tesseract::WeightMatrix::CountAlternators
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
Definition: weightmatrix.cpp:346
tesseract::TransposedArray::~TransposedArray
~TransposedArray() override
tesseract::WeightMatrix::AddDeltas
void AddDeltas(const WeightMatrix &other)
Definition: weightmatrix.cpp:337
tesseract::WeightMatrix::SumOuterTransposed
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
Definition: weightmatrix.cpp:284
tesseract::WeightMatrix::WeightMatrix
WeightMatrix()
Definition: weightmatrix.h:67
tesseract::TransposedArray::WriteStrided
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:39
tesseract::WeightMatrix::Update
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
Definition: weightmatrix.cpp:314
tesseract::TransposedArray::PrintUnTransposed
void PrintUnTransposed(int num)
Definition: weightmatrix.h:48
tesseract::WeightMatrix
Definition: weightmatrix.h:65
GENERIC_2D_ARRAY
Definition: intsimdmatrix.h:26
tesseract::WeightMatrix::Debug2D
void Debug2D(const char *msg)
Definition: weightmatrix.cpp:377
tesseract::WeightMatrix::InitBackward
void InitBackward()
Definition: weightmatrix.cpp:153
genericvector.h
GENERIC_2D_ARRAY< double >::dim2
int dim2() const
Definition: matrix.h:206
tesseract::TransposedArray::Transpose
void Transpose(const GENERIC_2D_ARRAY< double > &input)
Definition: weightmatrix.cpp:62
tesseract::WeightMatrix::RoundInputs
int RoundInputs(int size) const
Definition: weightmatrix.h:92
tesseract::WeightMatrix::DeSerialize
bool DeSerialize(bool training, TFile *fp)
Definition: weightmatrix.cpp:191
tesseract::WeightMatrix::Serialize
bool Serialize(bool training, TFile *fp) const
Definition: weightmatrix.cpp:172
matrix.h
tesseract::TFile
Definition: serialis.h:75
tesseract::IntSimdMatrix::intSimdMatrix
static const IntSimdMatrix * intSimdMatrix
Definition: intsimdmatrix.h:116
tesseract::IntSimdMatrix::RoundInputs
int RoundInputs(int size) const
Definition: intsimdmatrix.h:69
tesseract::WeightMatrix::DeSerializeOld
bool DeSerializeOld(bool training, TFile *fp)
Definition: weightmatrix.cpp:216
tesseract::WeightMatrix::ConvertToInt
void ConvertToInt()
Definition: weightmatrix.cpp:125
tesseract::WeightMatrix::InitWeightsFloat
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
Definition: weightmatrix.cpp:76
tesseract
Definition: baseapi.h:65
tprintf.h
GenericVector< double >
tesseract::WeightMatrix::RemapOutputs
int RemapOutputs(const std::vector< int > &code_map)
Definition: weightmatrix.cpp:97
tesseract::TransposedArray::WriteStrided
void WriteStrided(int t, const double *data)
Definition: weightmatrix.h:43
tesseract::WeightMatrix::MatrixDotVector
void MatrixDotVector(const double *u, double *v) const
Definition: weightmatrix.cpp:243
tesseract::TransposedArray
Definition: weightmatrix.h:32
tesseract::WeightMatrix::MultiplyAccumulate
void MultiplyAccumulate(const double *v, double *inout)
Definition: weightmatrix.cpp:260
tesseract::WeightMatrix::GetWeights
const double * GetWeights(int index) const
Definition: weightmatrix.h:103
GENERIC_2D_ARRAY< double >::put
void put(ICOORD pos, const double &thing)
Definition: matrix.h:219
tesseract::WeightMatrix::VectorDotMatrix
void VectorDotMatrix(const double *u, double *v) const
Definition: weightmatrix.cpp:274
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
intsimdmatrix.h
tesseract::WeightMatrix::NumOutputs
int NumOutputs() const
Definition: weightmatrix.h:101
tesseract::WeightMatrix::GetDW
double GetDW(int i, int j) const
Definition: weightmatrix.h:105
tesseract::WeightMatrix::FloatToDouble
static void FloatToDouble(const GENERIC_2D_ARRAY< float > &wf, GENERIC_2D_ARRAY< double > *wd)
Definition: weightmatrix.cpp:399
tesseract::WeightMatrix::is_int_mode
bool is_int_mode() const
Definition: weightmatrix.h:98
tesseract::TRand
Definition: helpers.h:50
GENERIC_2D_ARRAY< double >::dim1
int dim1() const
Definition: matrix.h:205