tesseract 5.0.0-alpha-619-ge9db
lstm.cpp
// File:        lstm.cpp
// Description: Long-term-short-term-memory Recurrent neural network.
// Author:      Ray Smith
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

17 
18 #include "lstm.h"
19 
20 #ifdef _OPENMP
21 #include <omp.h>
22 #endif
23 #include <cstdio>
24 #include <cstdlib>
25 
26 #if !defined(__GNUC__) && defined(_MSC_VER)
27 #include <intrin.h> // _BitScanReverse
28 #endif
29 
30 #include "fullyconnected.h"
31 #include "functions.h"
32 #include "networkscratch.h"
33 #include "tprintf.h"
34 
35 // Macros for openmp code if it is available, otherwise empty macros.
36 #ifdef _OPENMP
37 #define PARALLEL_IF_OPENMP(__num_threads) \
38  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
39  PRAGMA(omp sections nowait) { \
40  PRAGMA(omp section) {
41 #define SECTION_IF_OPENMP \
42  } \
43  PRAGMA(omp section) \
44  {
45 
46 #define END_PARALLEL_IF_OPENMP \
47  } \
48  } /* end of sections */ \
49  } /* end of parallel section */
50 
51 // Define the portable PRAGMA macro.
52 #ifdef _MSC_VER // Different _Pragma
53 #define PRAGMA(x) __pragma(x)
54 #else
55 #define PRAGMA(x) _Pragma(#x)
56 #endif // _MSC_VER
57 
58 #else // _OPENMP
59 #define PARALLEL_IF_OPENMP(__num_threads)
60 #define SECTION_IF_OPENMP
61 #define END_PARALLEL_IF_OPENMP
62 #endif // _OPENMP
63 
64 
65 namespace tesseract {
66 
// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const double kErrClip = 1.0;

// Calculate ceil(log2(n)).
static inline uint32_t ceil_log2(uint32_t n)
{
  // l2 = (unsigned)log2(n).
#if defined(__GNUC__)
  // Use the fast builtin of gcc and clang.
  uint32_t l2 = 31 - __builtin_clz(n);
#elif defined(_MSC_VER)
  // Use the fast intrinsic of the MS compiler.
  unsigned long l2 = 0;
  _BitScanReverse(&l2, n);
#else
  if (n == 0) return UINT_MAX;
  if (n == 1) return 0;
  uint32_t val = n;
  uint32_t l2 = 0;
  while (val > 1) {
    val >>= 1;
    l2++;
  }
#endif
  // Round up if n is not a power of 2.
  return (n == (1u << l2)) ? l2 : l2 + 1;
}
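// For example, ceil_log2(1) == 0, ceil_log2(2) == 1, ceil_log2(3) == 2 and
// ceil_log2(256) == 8. NT_LSTM_SOFTMAX_ENCODED uses this to size the binary
// code of the softmax output: nf_ bits are enough to encode no_ classes.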

LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
           NetworkType type)
    : Network(type, name, ni, no),
      na_(ni + ns),
      ns_(ns),
      nf_(0),
      is_2d_(two_dimensional),
      softmax_(nullptr),
      input_width_(0) {
  if (two_dimensional) na_ += ns_;
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is an invalid type for LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}
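// Note: the concatenated input to each gate is of width na_: the ni_ network
// inputs, then nf_ softmax-feedback features (softmax variants only), the ns_
// recurrent outputs, and ns_ more in 2-D mode for the second dimension; the
// weight matrices add one further bias input (na_ + 1 columns).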

LSTM::~LSTM() { delete softmax_; }

// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
  if (softmax_ != nullptr) return softmax_->OutputShape(result);
  return result;
}

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data if state is false.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) continue;
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != nullptr) softmax_->SetEnableTraining(state);
}

// Sets up the network for training. Initializes weights of scale `range`
// picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    num_weights_ += gate_weights_[w].InitWeightsFloat(
        ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != nullptr) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}

// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
  if (softmax_ != nullptr) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != nullptr) {
    softmax_->ConvertToInt();
  }
}

// Provides debug output on the weights.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    STRING msg = name_;
    msg.add_str_int(" Gate weights ", w);
    gate_weights_[w].Debug2D(msg.c_str());
  }
  if (softmax_ != nullptr) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (!fp->Serialize(&na_)) return false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
  }
  if (softmax_ != nullptr && !softmax_->Serialize(fp)) return false;
  return true;
}

// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile* fp) {
  if (!fp->DeSerialize(&na_)) return false;
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = ceil_log2(no_);
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
      // 2-D networks have na_ == ni_ + nf_ + 2 * ns_.
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) return false;
  } else {
    softmax_ = nullptr;
  }
  return true;
}

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO& input,
                   const TransposedArray* input_transpose,
                   NetworkScratch* scratch, NetworkIO* output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != nullptr)
    output->ResizeFloat(input, no_);
  else if (type_ == NT_LSTM_SUMMARY)
    output->ResizeXTo1(input, no_);
  else
    output->Resize(input, no_);
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  for (auto& temp_line : temp_lines) temp_line.Init(ns_, scratch);
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.init_to_size(buf_width, NetworkScratch::FloatVec());
    outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != nullptr) {
    softmax_output.Init(no_, scratch);
    ZeroVector<double>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode())
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    softmax_->SetupForward(input, nullptr);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != nullptr) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D())
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel outside the t loop, a single around
    // the t-loop and then tasks in place of the sections is a *lot* slower.
    // Cell inputs.
    if (source_.int_mode())
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    else
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode())
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    else
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode())
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    else
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode())
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      else
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode())
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    else
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

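    // The gate activations above combine in the standard LSTM cell update,
    // with cell input g = temp_lines[CI], input gate i = temp_lines[GI],
    // forget gate f = temp_lines[GF1] (max-pooled with temp_lines[GFS] in
    // 2-D mode) and output gate o = temp_lines[GO]:
    //   state  <- f * state + i * g   (clipped to +/-kStateClip)
    //   output <- o * h(state)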
    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      int8_t* which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const double* stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) state_.WriteTimeStep(t, curr_state);
    if (softmax_ != nullptr) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<double>(ns_, curr_state);
      ZeroVector<double>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.c_str());
  source_.Print(10);
  tprintf("State:%s\n", name_.c_str());
  state_.Print(10);
  tprintf("Output:%s\n", name_.c_str());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
                    NetworkScratch* scratch,
                    NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<double>(ns_, curr_stateerr);
  ZeroVector<double>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (auto& gate_error : gate_errors) gate_error.Init(ns_, scratch);
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<double>(ns_, stateerr[t]);
      ZeroVector<double>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (auto& sourceerr_temp : sourceerr_temps)
    sourceerr_temp.Init(na_, scratch);
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (auto& w : gate_errors_t) {
    w.Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.c_str());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
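  // The loop below mirrors the scan in Forward() in reverse: it walks the
  // timesteps from last to first (InitToLast()/Decrement()), accumulating
  // state error from the following 1-D timestep and, in 2-D mode, from the
  // neighboring strip, before distributing the error to each gate.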
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<double>(na_, curr_sourceerr);
      ZeroVector<double>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<double>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
                                 softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x)
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    if (down_pos >= 0)
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float* next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
      }
      if (down_pos >= 0) {
        const float* right_node_gfs = node_values_[GFS].f(down_pos);
        const double* right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
                                    curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
                curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
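    // Each gate error is the state (or output) error times the gate's
    // derivative at its recorded forward activation; e.g. for the cell input,
    // d(state)/d(CI) = GI, so err_CI = G'(CI) * GI * stateerr. Errors are
    // clipped to +/-kErrClip before VectorDotMatrix maps them back onto the
    // source error.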
    PARALLEL_IF_OPENMP(GFS)
    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
                                           curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
                                           curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
                                         sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
                                         sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
                                         gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
               sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
               curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.c_str(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}

// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta,
                  int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != nullptr) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network& other, double* same,
                            double* changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM* lstm = static_cast<const LSTM*>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != nullptr) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.c_str());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    tprintf("\n");
  }
}

// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO& input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) continue;
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}


} // namespace tesseract.
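The sketch below is not part of lstm.cpp; it illustrates how a plain 1-D LSTM
layer could be constructed and initialized directly, assuming the Tesseract
headers used above are available. The layer sizes and the name "example_lstm"
are made up for illustration; in practice networkbuilder creates LSTM
instances from a VGSL network spec.

#include "lstm.h"
#include "tprintf.h"

void BuildExampleLayer() {
  // 20 inputs and 96 states; for NT_LSTM the constructor requires no == ns.
  tesseract::LSTM lstm("example_lstm", 20, 96, 96,
                       /*two_dimensional=*/false, tesseract::NT_LSTM);
  // Initialize weights of scale 0.1 using a default-constructed TRand.
  tesseract::TRand randomizer;
  int num_weights = lstm.InitWeights(0.1f, &randomizer);
  tprintf("Initialized %d weights\n", num_weights);
}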
Definition: lstm.h:119