tesseract  5.0.0-alpha-619-ge9db
fullyconnected.cpp
Go to the documentation of this file.
1 // File: fullyconnected.cpp
3 // Description: Simple feed-forward layer with various non-linearities.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2014, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #include "fullyconnected.h"
19 
20 #ifdef _OPENMP
21 #include <omp.h>
22 #endif
23 #include <cstdio>
24 #include <cstdlib>
25 
26 #include "functions.h"
27 #include "networkscratch.h"
28 
29 // Number of threads to use for parallel calculation of Forward and Backward.
30 #ifdef _OPENMP
31 const int kNumThreads = 4;
32 #else
33 const int kNumThreads = 1;
34 #endif
35 
36 namespace tesseract {
37 
38 FullyConnected::FullyConnected(const STRING& name, int ni, int no,
40  : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {
41 }
42 
43 // Returns the shape output from the network given an input shape (which may
44 // be partially unknown, i.e. zero).
46  LossType loss_type = LT_NONE;
47  if (type_ == NT_SOFTMAX)
48  loss_type = LT_CTC;
49  else if (type_ == NT_SOFTMAX_NO_CTC)
50  loss_type = LT_SOFTMAX;
51  else if (type_ == NT_LOGISTIC)
52  loss_type = LT_LOGISTIC;
53  StaticShape result(input_shape);
54  result.set_depth(no_);
55  result.set_loss_type(loss_type);
56  return result;
57 }
58 
59 // Suspends/Enables training by setting the training_ flag.
61  if (state == TS_RE_ENABLE) {
62  // Enable only from temp disabled.
64  } else if (state == TS_TEMP_DISABLE) {
65  // Temp disable only from enabled.
66  if (training_ == TS_ENABLED) training_ = state;
67  } else {
68  if (state == TS_ENABLED && training_ != TS_ENABLED)
70  training_ = state;
71  }
72 }
73 
74 // Sets up the network for training. Initializes weights using weights of
75 // scale `range` picked according to the random number generator `randomizer`.
76 int FullyConnected::InitWeights(float range, TRand* randomizer) {
77  Network::SetRandomizer(randomizer);
79  range, randomizer);
80  return num_weights_;
81 }
82 
83 // Recursively searches the network for softmaxes with old_no outputs,
84 // and remaps their outputs according to code_map. See network.h for details.
85 
86 int FullyConnected::RemapOutputs(int old_no, const std::vector<int>& code_map) {
87  if (type_ == NT_SOFTMAX && no_ == old_no) {
89  no_ = code_map.size();
90  }
91  return num_weights_;
92 }
93 
94 // Converts a float network to an int network.
97 }
98 
99 // Provides debug output on the weights.
102 }
103 
104 // Writes to the given file. Returns false in case of error.
106  if (!Network::Serialize(fp)) return false;
107  if (!weights_.Serialize(IsTraining(), fp)) return false;
108  return true;
109 }
110 
111 // Reads from the given file. Returns false in case of error.
113  return weights_.DeSerialize(IsTraining(), fp);
114 }
115 
116 // Runs forward propagation of activations on the input line.
117 // See NetworkCpp for a detailed discussion of the arguments.
118 void FullyConnected::Forward(bool debug, const NetworkIO& input,
119  const TransposedArray* input_transpose,
120  NetworkScratch* scratch, NetworkIO* output) {
121  int width = input.Width();
122  if (type_ == NT_SOFTMAX)
123  output->ResizeFloat(input, no_);
124  else
125  output->Resize(input, no_);
126  SetupForward(input, input_transpose);
131  for (int i = 0; i < kNumThreads; ++i) {
132  temp_lines[i].Init(no_, scratch);
133  curr_input[i].Init(ni_, scratch);
134  }
135 #ifdef _OPENMP
136 #pragma omp parallel for num_threads(kNumThreads)
137  for (int t = 0; t < width; ++t) {
138  // Thread-local pointer to temporary storage.
139  int thread_id = omp_get_thread_num();
140 #else
141  for (int t = 0; t < width; ++t) {
142  // Thread-local pointer to temporary storage.
143  int thread_id = 0;
144 #endif
145  double* temp_line = temp_lines[thread_id];
146  if (input.int_mode()) {
147  ForwardTimeStep(input.i(t), t, temp_line);
148  } else {
149  input.ReadTimeStep(t, curr_input[thread_id]);
150  ForwardTimeStep(curr_input[thread_id], t, temp_line);
151  }
152  output->WriteTimeStep(t, temp_line);
153  if (IsTraining() && type_ != NT_SOFTMAX) {
154  acts_.CopyTimeStepFrom(t, *output, t);
155  }
156  }
157  // Zero all the elements that are in the padding around images that allows
158  // multiple different-sized images to exist in a single array.
159  // acts_ is only used if this is not a softmax op.
160  if (IsTraining() && type_ != NT_SOFTMAX) {
162  }
163  output->ZeroInvalidElements();
164 #if DEBUG_DETAIL > 0
165  tprintf("F Output:%s\n", name_.c_str());
166  output->Print(10);
167 #endif
168  if (debug) DisplayForward(*output);
169 }
170 
171 // Components of Forward so FullyConnected can be reused inside LSTM.
173  const TransposedArray* input_transpose) {
174  // Softmax output is always float, so save the input type.
175  int_mode_ = input.int_mode();
176  if (IsTraining()) {
177  acts_.Resize(input, no_);
178  // Source_ is a transposed copy of input. It isn't needed if provided.
179  external_source_ = input_transpose;
180  if (external_source_ == nullptr) source_t_.ResizeNoInit(ni_, input.Width());
181  }
182 }
183 
184 void FullyConnected::ForwardTimeStep(int t, double* output_line) {
185  if (type_ == NT_TANH) {
186  FuncInplace<GFunc>(no_, output_line);
187  } else if (type_ == NT_LOGISTIC) {
188  FuncInplace<FFunc>(no_, output_line);
189  } else if (type_ == NT_POSCLIP) {
190  FuncInplace<ClipFFunc>(no_, output_line);
191  } else if (type_ == NT_SYMCLIP) {
192  FuncInplace<ClipGFunc>(no_, output_line);
193  } else if (type_ == NT_RELU) {
194  FuncInplace<Relu>(no_, output_line);
195  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
196  SoftmaxInPlace(no_, output_line);
197  } else if (type_ != NT_LINEAR) {
198  ASSERT_HOST("Invalid fully-connected type!" == nullptr);
199  }
200 }
201 
202 void FullyConnected::ForwardTimeStep(const double* d_input,
203  int t, double* output_line) {
204  // input is copied to source_ line-by-line for cache coherency.
205  if (IsTraining() && external_source_ == nullptr)
206  source_t_.WriteStrided(t, d_input);
207  weights_.MatrixDotVector(d_input, output_line);
208  ForwardTimeStep(t, output_line);
209 }
210 
211 void FullyConnected::ForwardTimeStep(const int8_t* i_input,
212  int t, double* output_line) {
213  // input is copied to source_ line-by-line for cache coherency.
214  weights_.MatrixDotVector(i_input, output_line);
215  ForwardTimeStep(t, output_line);
216 }
217 
218 // Runs backward propagation of errors on the deltas line.
219 // See NetworkCpp for a detailed discussion of the arguments.
220 bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
221  NetworkScratch* scratch,
222  NetworkIO* back_deltas) {
223  if (debug) DisplayBackward(fwd_deltas);
224  back_deltas->Resize(fwd_deltas, ni_);
227  for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
229  if (needs_to_backprop_) {
231  for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
232  }
233  int width = fwd_deltas.Width();
235  errors_t.Init(no_, width, scratch);
236 #ifdef _OPENMP
237 #pragma omp parallel for num_threads(kNumThreads)
238  for (int t = 0; t < width; ++t) {
239  int thread_id = omp_get_thread_num();
240 #else
241  for (int t = 0; t < width; ++t) {
242  int thread_id = 0;
243 #endif
244  double* backprop = nullptr;
245  if (needs_to_backprop_) backprop = temp_backprops[thread_id];
246  double* curr_errors = errors[thread_id];
247  BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
248  if (backprop != nullptr) {
249  back_deltas->WriteTimeStep(t, backprop);
250  }
251  }
252  FinishBackward(*errors_t.get());
253  if (needs_to_backprop_) {
254  back_deltas->ZeroInvalidElements();
255 #if DEBUG_DETAIL > 0
256  tprintf("F Backprop:%s\n", name_.c_str());
257  back_deltas->Print(10);
258 #endif
259  return true;
260  }
261  return false; // No point going further back.
262 }
263 
264 void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t,
265  double* curr_errors,
266  TransposedArray* errors_t,
267  double* backprop) {
268  if (type_ == NT_TANH)
269  acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
270  else if (type_ == NT_LOGISTIC)
271  acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
272  else if (type_ == NT_POSCLIP)
273  acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
274  else if (type_ == NT_SYMCLIP)
275  acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
276  else if (type_ == NT_RELU)
277  acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
278  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
279  type_ == NT_LINEAR)
280  fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
281  else
282  ASSERT_HOST("Invalid fully-connected type!" == nullptr);
283  // Generate backprop only if needed by the lower layer.
284  if (backprop != nullptr) weights_.VectorDotMatrix(curr_errors, backprop);
285  errors_t->WriteStrided(t, curr_errors);
286 }
287 
289  if (external_source_ == nullptr)
290  weights_.SumOuterTransposed(errors_t, source_t_, true);
291  else
292  weights_.SumOuterTransposed(errors_t, *external_source_, true);
293 }
294 
295 // Updates the weights using the given learning rate, momentum and adam_beta.
296 // num_samples is used in the adam computation iff use_adam_ is true.
297 void FullyConnected::Update(float learning_rate, float momentum,
298  float adam_beta, int num_samples) {
299  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
300 }
301 
302 // Sums the products of weight updates in *this and other, splitting into
303 // positive (same direction) in *same and negative (different direction) in
304 // *changed.
305 void FullyConnected::CountAlternators(const Network& other, double* same,
306  double* changed) const {
307  ASSERT_HOST(other.type() == type_);
308  const auto* fc = static_cast<const FullyConnected*>(&other);
309  weights_.CountAlternators(fc->weights_, same, changed);
310 }
311 
312 } // namespace tesseract.
tesseract::FullyConnected::external_source_
const TransposedArray * external_source_
Definition: fullyconnected.h:124
tesseract::ClipGPrime
Definition: functions.h:106
tesseract::FullyConnected::InitWeights
int InitWeights(float range, TRand *randomizer) override
Definition: fullyconnected.cpp:76
tesseract::FullyConnected::source_t_
TransposedArray source_t_
Definition: fullyconnected.h:121
tesseract::TS_ENABLED
Definition: network.h:95
tesseract::StaticShape
Definition: static_shape.h:38
tesseract::NT_POSCLIP
Definition: network.h:63
tesseract::WeightMatrix::CountAlternators
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
Definition: weightmatrix.cpp:346
tesseract::NetworkIO::i
const int8_t * i(int t) const
Definition: networkio.h:123
tesseract::LT_CTC
Definition: static_shape.h:31
tesseract::FullyConnected::FinishBackward
void FinishBackward(const TransposedArray &errors_t)
Definition: fullyconnected.cpp:288
tesseract::Network::SetRandomizer
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
tesseract::Network::DisplayForward
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:288
tesseract::FullyConnected::Update
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
Definition: fullyconnected.cpp:297
tesseract::FullyConnected::DeSerialize
bool DeSerialize(TFile *fp) override
Definition: fullyconnected.cpp:112
tesseract::NetworkIO::ZeroInvalidElements
void ZeroInvalidElements()
Definition: networkio.cpp:88
tesseract::NetworkIO::int_mode
bool int_mode() const
Definition: networkio.h:127
tesseract::FullyConnected::Backward
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: fullyconnected.cpp:220
tesseract::NT_SOFTMAX_NO_CTC
Definition: network.h:69
tesseract::FullyConnected::SetupForward
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
Definition: fullyconnected.cpp:172
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::FullyConnected::CountAlternators
void CountAlternators(const Network &other, double *same, double *changed) const override
Definition: fullyconnected.cpp:305
tesseract::WeightMatrix::SumOuterTransposed
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
Definition: weightmatrix.cpp:284
tesseract::NetworkScratch::FloatVec
Definition: networkscratch.h:134
tesseract::FullyConnected::RemapOutputs
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: fullyconnected.cpp:86
tesseract::TransposedArray::WriteStrided
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:39
STRING
Definition: strngs.h:45
tesseract::NetworkIO::Width
int Width() const
Definition: networkio.h:107
tesseract::WeightMatrix::Update
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
Definition: weightmatrix.cpp:314
tesseract::ClipFPrime
Definition: functions.h:79
tesseract::NetworkIO::FuncMultiply
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
Definition: networkio.h:259
tesseract::NetworkScratch
Definition: networkscratch.h:34
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::NetworkIO::ResizeFloat
void ResizeFloat(const NetworkIO &src, int num_features)
Definition: networkio.h:52
tesseract::NetworkType
NetworkType
Definition: network.h:43
tesseract::Network::needs_to_backprop_
bool needs_to_backprop_
Definition: network.h:295
tesseract::WeightMatrix::Debug2D
void Debug2D(const char *msg)
Definition: weightmatrix.cpp:377
tesseract::NetworkScratch::GradientStore::Init
void Init(int size1, int size2, NetworkScratch *scratch)
Definition: networkscratch.h:182
tesseract::WeightMatrix::InitBackward
void InitBackward()
Definition: weightmatrix.cpp:153
tesseract::NT_SYMCLIP
Definition: network.h:64
tesseract::Network::TestFlag
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
tesseract::Network::IsTraining
bool IsTraining() const
Definition: network.h:115
tesseract::FullyConnected::SetEnableTraining
void SetEnableTraining(TrainingState state) override
Definition: fullyconnected.cpp:60
tesseract::NetworkScratch::GradientStore::get
TransposedArray * get() const
Definition: networkscratch.h:191
tesseract::LT_NONE
Definition: static_shape.h:30
tesseract::FullyConnected::BackwardTimeStep
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
Definition: fullyconnected.cpp:264
tesseract::Network::name_
STRING name_
Definition: network.h:300
tesseract::SoftmaxInPlace
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:146
networkscratch.h
tesseract::FullyConnected::Forward
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: fullyconnected.cpp:118
tesseract::LossType
LossType
Definition: static_shape.h:29
tesseract::Network::type_
NetworkType type_
Definition: network.h:293
tesseract::NF_ADAM
Definition: network.h:88
tesseract::FullyConnected::weights_
WeightMatrix weights_
Definition: fullyconnected.h:119
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
tesseract::FullyConnected::ForwardTimeStep
void ForwardTimeStep(int t, double *output_line)
Definition: fullyconnected.cpp:184
tesseract::WeightMatrix::DeSerialize
bool DeSerialize(bool training, TFile *fp)
Definition: weightmatrix.cpp:191
GENERIC_2D_ARRAY::ResizeNoInit
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:90
tesseract::FullyConnected::OutputShape
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: fullyconnected.cpp:45
tesseract::WeightMatrix::Serialize
bool Serialize(bool training, TFile *fp) const
Definition: weightmatrix.cpp:172
tesseract::NT_TANH
Definition: network.h:65
tesseract::NetworkIO::Print
void Print(int num) const
Definition: networkio.cpp:366
tesseract::TFile
Definition: serialis.h:75
tesseract::NetworkIO::CopyTimeStepFrom
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:383
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::FullyConnected::FullyConnected
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
Definition: fullyconnected.cpp:38
tesseract::FPrime
Definition: functions.h:69
tesseract::WeightMatrix::ConvertToInt
void ConvertToInt()
Definition: weightmatrix.cpp:125
tesseract::WeightMatrix::InitWeightsFloat
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
Definition: weightmatrix.cpp:76
tesseract::Network::training_
TrainingState training_
Definition: network.h:294
tesseract::FullyConnected::Serialize
bool Serialize(TFile *fp) const override
Definition: fullyconnected.cpp:105
tesseract::NetworkScratch::GradientStore
Definition: networkscratch.h:174
tesseract::TS_RE_ENABLE
Definition: network.h:99
tesseract
Definition: baseapi.h:65
kNumThreads
const int kNumThreads
Definition: fullyconnected.cpp:33
tesseract::FullyConnected::acts_
NetworkIO acts_
Definition: fullyconnected.h:126
tesseract::NetworkIO::ReadTimeStep
void ReadTimeStep(int t, double *output) const
Definition: networkio.cpp:598
tesseract::StaticShape::set_loss_type
void set_loss_type(LossType value)
Definition: static_shape.h:51
GenericVector
Definition: baseapi.h:40
tesseract::LT_SOFTMAX
Definition: static_shape.h:32
tesseract::Network
Definition: network.h:105
tesseract::WeightMatrix::RemapOutputs
int RemapOutputs(const std::vector< int > &code_map)
Definition: weightmatrix.cpp:97
tesseract::FullyConnected::ConvertToInt
void ConvertToInt() override
Definition: fullyconnected.cpp:95
tesseract::NetworkIO::WriteTimeStep
void WriteTimeStep(int t, const double *input)
Definition: networkio.cpp:645
tesseract::Network::num_weights_
int32_t num_weights_
Definition: network.h:299
tesseract::WeightMatrix::MatrixDotVector
void MatrixDotVector(const double *u, double *v) const
Definition: weightmatrix.cpp:243
tesseract::NetworkIO::Resize
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
tesseract::GPrime
Definition: functions.h:96
fullyconnected.h
tesseract::NT_RELU
Definition: network.h:66
tesseract::TransposedArray
Definition: weightmatrix.h:32
tesseract::TrainingState
TrainingState
Definition: network.h:92
tesseract::TS_TEMP_DISABLE
Definition: network.h:97
GenericVector::init_to_size
void init_to_size(int size, const T &t)
Definition: genericvector.h:706
tesseract::WeightMatrix::VectorDotMatrix
void VectorDotMatrix(const double *u, double *v) const
Definition: weightmatrix.cpp:274
functions.h
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::ReluPrime
Definition: functions.h:90
tesstrain_utils.type
type
Definition: tesstrain_utils.py:141
tesseract::NT_LINEAR
Definition: network.h:67
tesseract::Network::Serialize
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
tesseract::NT_LOGISTIC
Definition: network.h:62
tesseract::Network::no_
int32_t no_
Definition: network.h:298
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
tesseract::TRand
Definition: helpers.h:50
tesseract::Network::DisplayBackward
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:299
tesseract::LT_LOGISTIC
Definition: static_shape.h:33
tesseract::FullyConnected::DebugWeights
void DebugWeights() override
Definition: fullyconnected.cpp:100
tesseract::FullyConnected::int_mode_
bool int_mode_
Definition: fullyconnected.h:129
tesseract::StaticShape::set_depth
void set_depth(int value)
Definition: static_shape.h:49
tesseract::NT_SOFTMAX
Definition: network.h:68