// File:        lstm.cpp
// Description: Long Short-Term Memory (LSTM) recurrent neural network.
// Author:      Ray Smith
// Created:     Wed May 01 17:43:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lstm.h"

#ifdef _OPENMP
#include <omp.h>
#endif
#include <climits>  // UINT_MAX, used by the portable ceil_log2 fallback.
#include <cstdio>
#include <cstdlib>

#if !defined(__GNUC__) && defined(_MSC_VER)
#include <intrin.h> // _BitScanReverse
#endif

#include "fullyconnected.h"
#include "functions.h"
#include "networkscratch.h"
#include "tprintf.h"

// Macros for openmp code if it is available, otherwise empty macros.
#ifdef _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads) \
  PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
    PRAGMA(omp sections nowait) { \
      PRAGMA(omp section) {
#define SECTION_IF_OPENMP \
  } \
  PRAGMA(omp section) \
  {

#define END_PARALLEL_IF_OPENMP \
  } \
  } /* end of sections */ \
  } /* end of parallel section */

// Define the portable PRAGMA macro.
#ifdef _MSC_VER // Different _Pragma
#define PRAGMA(x) __pragma(x)
#else
#define PRAGMA(x) _Pragma(#x)
#endif // _MSC_VER

#else // _OPENMP
#define PARALLEL_IF_OPENMP(__num_threads)
#define SECTION_IF_OPENMP
#define END_PARALLEL_IF_OPENMP
#endif // _OPENMP

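// For example, a block written as
//   PARALLEL_IF_OPENMP(GFS)
//     ... cell-input work ...
//   SECTION_IF_OPENMP
//     ... input-gate work ...
//   END_PARALLEL_IF_OPENMP
// expands under OpenMP to nested "omp parallel"/"omp sections"/"omp section"
// regions, so each gate's matrix multiply can run as a concurrent section;
// without OpenMP the macros expand to nothing and the code runs sequentially.
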
namespace tesseract {

// Max absolute value of state_. It is reasonably high to enable the state
// to count things.
const double kStateClip = 100.0;
// Max absolute value of gate_errors (the gradients).
const double kErrClip = 1.0;

// Calculates ceil(log2(n)). The builtin/intrinsic paths require n > 0.
static inline uint32_t ceil_log2(uint32_t n)
{
  // l2 = (unsigned)log2(n).
#if defined(__GNUC__)
  // Use the fast builtin for gcc or clang (undefined for n == 0).
  uint32_t l2 = 31 - __builtin_clz(n);
#elif defined(_MSC_VER)
  // Use the fast intrinsic for the MS compiler.
  unsigned long l2 = 0;
  _BitScanReverse(&l2, n);
#else
  // Portable fallback.
  if (n == 0) return UINT_MAX;
  if (n == 1) return 0;
  uint32_t val = n;
  uint32_t l2 = 0;
  while (val > 1) {
    val >>= 1;
    l2++;
  }
#endif
  // Round up if n is not a power of 2.
  return (n == (1u << l2)) ? l2 : l2 + 1;
}

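// For example, ceil_log2(1) == 0, ceil_log2(3) == 2, ceil_log2(8) == 3 and
// ceil_log2(9) == 4. It is used below to size the binary output code of
// NT_LSTM_SOFTMAX_ENCODED: no_ classes need ceil_log2(no_) code bits.
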
LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
           NetworkType type)
    : Network(type, name, ni, no),
      na_(ni + ns),
      ns_(ns),
      nf_(0),
      is_2d_(two_dimensional),
      softmax_(nullptr),
      input_width_(0) {
  if (two_dimensional) na_ += ns_;
  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
    nf_ = 0;
    // networkbuilder ensures this is always true.
    ASSERT_HOST(no == ns);
  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
    softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
  } else {
    tprintf("%d is an invalid type for LSTM!\n", type);
    ASSERT_HOST(false);
  }
  na_ += nf_;
}

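// Worked example of the sizing above (illustrative numbers): an
// NT_LSTM_SOFTMAX_ENCODED layer with ni = 100 inputs, ns = 96 states and
// no = 111 outputs gets nf_ = ceil_log2(111) = 7 feedback features, so each
// gate sees na_ = ni + ns + nf_ = 100 + 96 + 7 = 203 inputs (plus another
// ns_ in 2-D mode).
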
LSTM::~LSTM() { delete softmax_; }

// Returns the shape output from the network given an input shape (which may
// be partially unknown ie zero).
StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
  StaticShape result = input_shape;
  result.set_depth(no_);
  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
  if (softmax_ != nullptr) return softmax_->OutputShape(result);
  return result;
}

// Suspends/Enables training by setting the training_ flag. Serialize and
// DeSerialize only operate on the run-time data when training is disabled.
void LSTM::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED) {
      for (int w = 0; w < WT_COUNT; ++w) {
        if (w == GFS && !Is2D()) continue;
        gate_weights_[w].InitBackward();
      }
    }
    training_ = state;
  }
  if (softmax_ != nullptr) softmax_->SetEnableTraining(state);
}

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int LSTM::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = 0;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    num_weights_ += gate_weights_[w].InitWeightsFloat(
        ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
  }
  if (softmax_ != nullptr) {
    num_weights_ += softmax_->InitWeights(range, randomizer);
  }
  return num_weights_;
}

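// Each gate thus owns an ns_ x (na_ + 1) weight matrix: one row per state
// cell and one column per concatenated input, plus a trailing bias column.
// A plain 1-D NT_LSTM therefore holds 4 * ns_ * (ni_ + ns_ + 1) weights
// across its CI, GI, GF1 and GO gates.
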
// Recursively searches the network for softmaxes with old_no outputs,
// and remaps their outputs according to code_map. See network.h for details.
int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
  if (softmax_ != nullptr) {
    num_weights_ -= softmax_->num_weights();
    num_weights_ += softmax_->RemapOutputs(old_no, code_map);
  }
  return num_weights_;
}

// Converts a float network to an int network.
void LSTM::ConvertToInt() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].ConvertToInt();
  }
  if (softmax_ != nullptr) {
    softmax_->ConvertToInt();
  }
}

// Provides debug output on the weights.
void LSTM::DebugWeights() {
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    STRING msg = name_;
    msg.add_str_int(" Gate weights ", w);
    gate_weights_[w].Debug2D(msg.string());
  }
  if (softmax_ != nullptr) {
    softmax_->DebugWeights();
  }
}

// Writes to the given file. Returns false in case of error.
bool LSTM::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (!fp->Serialize(&na_)) return false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
  }
  if (softmax_ != nullptr && !softmax_->Serialize(fp)) return false;
  return true;
}

// Reads from the given file. Returns false in case of error.
bool LSTM::DeSerialize(TFile* fp) {
  if (!fp->DeSerialize(&na_)) return false;
  if (type_ == NT_LSTM_SOFTMAX) {
    nf_ = no_;
  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
    nf_ = ceil_log2(no_);
  } else {
    nf_ = 0;
  }
  is_2d_ = false;
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
    if (w == CI) {
      ns_ = gate_weights_[CI].NumOutputs();
      is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
    }
  }
  delete softmax_;
  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
    softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
    if (softmax_ == nullptr) return false;
  } else {
    softmax_ = nullptr;
  }
  return true;
}

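// Note on the dimension recovery above: na_ is stored in the file, ni_ and
// no_ come from the Network base, and nf_ follows from type_, so ns_ is read
// back from the CI weight matrix and the 2-D flag falls out of the identity
// na_ = ni_ + nf_ + ns_ (1-D) versus na_ = ni_ + nf_ + 2 * ns_ (2-D).
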
// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void LSTM::Forward(bool debug, const NetworkIO& input,
                   const TransposedArray* input_transpose,
                   NetworkScratch* scratch, NetworkIO* output) {
  input_map_ = input.stride_map();
  input_width_ = input.Width();
  if (softmax_ != nullptr)
    output->ResizeFloat(input, no_);
  else if (type_ == NT_LSTM_SUMMARY)
    output->ResizeXTo1(input, no_);
  else
    output->Resize(input, no_);
  ResizeForward(input);
  // Temporary storage of forward computation for each gate.
  NetworkScratch::FloatVec temp_lines[WT_COUNT];
  for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
  // Single timestep buffers for the current/recurrent output and state.
  NetworkScratch::FloatVec curr_state, curr_output;
  curr_state.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_state);
  curr_output.Init(ns_, scratch);
  ZeroVector<double>(ns_, curr_output);
  // Rotating buffers of width buf_width allow storage of the state and output
  // for the other dimension, used only when working in true 2D mode. The width
  // is enough to hold an entire strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> states, outputs;
  if (Is2D()) {
    states.init_to_size(buf_width, NetworkScratch::FloatVec());
    outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int i = 0; i < buf_width; ++i) {
      states[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, states[i]);
      outputs[i].Init(ns_, scratch);
      ZeroVector<double>(ns_, outputs[i]);
    }
  }
  // Used only if a softmax LSTM.
  NetworkScratch::FloatVec softmax_output;
  NetworkScratch::IO int_output;
  if (softmax_ != nullptr) {
    softmax_output.Init(no_, scratch);
    ZeroVector<double>(no_, softmax_output);
    int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
    if (input.int_mode())
      int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
    softmax_->SetupForward(input, nullptr);
  }
  NetworkScratch::FloatVec curr_input;
  curr_input.Init(na_, scratch);
  StrideMap::Index src_index(input_map_);
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index dest_index(output->stride_map());
  do {
    int t = src_index.t();
    // True if there is a valid old state for the 2nd dimension.
    bool valid_2d = Is2D();
    if (valid_2d) {
      StrideMap::Index dim_index(src_index);
      if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
    }
    // Index of the 2-D revolving buffers (outputs, states).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Setup the padded input in source.
    source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
    if (softmax_ != nullptr) {
      source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
    }
    source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
    if (Is2D())
      source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
    if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
    // Matrix multiply the inputs with the source.
    PARALLEL_IF_OPENMP(GFS)
    // It looks inefficient to create the threads on each t iteration, but the
    // alternative of putting the parallel outside the t loop, a single around
    // the t-loop and then tasks in place of the sections is a *lot* slower.
    // Cell inputs.
    if (source_.int_mode())
      gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
    else
      gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
    FuncInplace<GFunc>(ns_, temp_lines[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    if (source_.int_mode())
      gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
    else
      gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
    FuncInplace<FFunc>(ns_, temp_lines[GI]);

    SECTION_IF_OPENMP
    // 1-D forget gates.
    if (source_.int_mode())
      gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
    else
      gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
    FuncInplace<FFunc>(ns_, temp_lines[GF1]);

    // 2-D forget gates.
    if (Is2D()) {
      if (source_.int_mode())
        gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
      else
        gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
      FuncInplace<FFunc>(ns_, temp_lines[GFS]);
    }

    SECTION_IF_OPENMP
    // Output gates.
    if (source_.int_mode())
      gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
    else
      gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
    FuncInplace<FFunc>(ns_, temp_lines[GO]);
    END_PARALLEL_IF_OPENMP

    // Apply forget gate to state.
    MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
    if (Is2D()) {
      // Max-pool the forget gates (in 2-d) instead of blindly adding.
      int8_t* which_fg_col = which_fg_[t];
      memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
      if (valid_2d) {
        const double* stepped_state = states[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
            curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
            which_fg_col[i] = 2;
          }
        }
      }
    }
    MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
    // Clip curr_state to a sane range.
    ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
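    // In conventional LSTM notation, the update just computed is
    //   s_t = f_t (*) s_{t-1} + i_t (*) c_t   ((*) = elementwise product),
    // where c_t is the CI line after GFunc, i_t the GI gate and f_t the GF1
    // gate; in 2-D mode, per cell, the larger of the two forget products
    // (GF1 with the previous state in the major direction, GFS with the
    // saved state from the other dimension) wins, as recorded in which_fg_.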
    if (IsTraining()) {
      // Save the gate node values.
      node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
      node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
      node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
      node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
      if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
    }
    FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
    if (IsTraining()) state_.WriteTimeStep(t, curr_state);
    if (softmax_ != nullptr) {
      if (input.int_mode()) {
        int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
        softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
      } else {
        softmax_->ForwardTimeStep(curr_output, t, softmax_output);
      }
      output->WriteTimeStep(t, softmax_output);
      if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
        CodeInBinary(no_, nf_, softmax_output);
      }
    } else if (type_ == NT_LSTM_SUMMARY) {
      // Output only at the end of a row.
      if (src_index.IsLast(FD_WIDTH)) {
        output->WriteTimeStep(dest_index.t(), curr_output);
        dest_index.Increment();
      }
    } else {
      output->WriteTimeStep(t, curr_output);
    }
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_state, states[mod_t]);
      CopyVector(ns_, curr_output, outputs[mod_t]);
    }
    // Always zero the states at the end of every row, but only for the major
    // direction. The 2-D state remains intact.
    if (src_index.IsLast(FD_WIDTH)) {
      ZeroVector<double>(ns_, curr_state);
      ZeroVector<double>(ns_, curr_output);
    }
  } while (src_index.Increment());
#if DEBUG_DETAIL > 0
  tprintf("Source:%s\n", name_.string());
  source_.Print(10);
  tprintf("State:%s\n", name_.string());
  state_.Print(10);
  tprintf("Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
                    NetworkScratch* scratch,
                    NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
  // ======Scratch space.======
  // Output errors from deltas with recurrence from sourceerr.
  NetworkScratch::FloatVec outputerr;
  outputerr.Init(ns_, scratch);
  // Recurrent error in the state/source.
  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
  curr_stateerr.Init(ns_, scratch);
  curr_sourceerr.Init(na_, scratch);
  ZeroVector<double>(ns_, curr_stateerr);
  ZeroVector<double>(na_, curr_sourceerr);
  // Errors in the gates.
  NetworkScratch::FloatVec gate_errors[WT_COUNT];
  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
  // Rotating buffers of width buf_width allow storage of the recurrent time-
  // steps used only for true 2-D. Stores one full strip of the major direction.
  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
  if (Is2D()) {
    stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
    for (int t = 0; t < buf_width; ++t) {
      stateerr[t].Init(ns_, scratch);
      sourceerr[t].Init(na_, scratch);
      ZeroVector<double>(ns_, stateerr[t]);
      ZeroVector<double>(na_, sourceerr[t]);
    }
  }
  // Parallel-generated sourceerr from each of the gates.
  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w)
    sourceerr_temps[w].Init(na_, scratch);
  int width = input_width_;
  // Transposed gate errors stored over all timesteps for sum outer.
  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
  for (int w = 0; w < WT_COUNT; ++w) {
    gate_errors_t[w].Init(ns_, width, scratch);
  }
  // Used only if softmax_ != nullptr.
  NetworkScratch::FloatVec softmax_errors;
  NetworkScratch::GradientStore softmax_errors_t;
  if (softmax_ != nullptr) {
    softmax_errors.Init(no_, scratch);
    softmax_errors_t.Init(no_, width, scratch);
  }
  double state_clip = Is2D() ? 9.0 : 4.0;
#if DEBUG_DETAIL > 1
  tprintf("fwd_deltas:%s\n", name_.string());
  fwd_deltas.Print(10);
#endif
  StrideMap::Index dest_index(input_map_);
  dest_index.InitToLast();
  // Used only by NT_LSTM_SUMMARY.
  StrideMap::Index src_index(fwd_deltas.stride_map());
  src_index.InitToLast();
  do {
    int t = dest_index.t();
    bool at_last_x = dest_index.IsLast(FD_WIDTH);
    // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
    // valid if >= 0, which is true if 2d and not on the top/bottom.
    int up_pos = -1;
    int down_pos = -1;
    if (Is2D()) {
      if (dest_index.index(FD_HEIGHT) > 0) {
        StrideMap::Index up_index(dest_index);
        if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
      }
      if (!dest_index.IsLast(FD_HEIGHT)) {
        StrideMap::Index down_index(dest_index);
        if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
      }
    }
    // Index of the 2-D revolving buffers (sourceerr, stateerr).
    int mod_t = Modulo(t, buf_width); // Current timestep.
    // Zero the state in the major direction only at the end of every row.
    if (at_last_x) {
      ZeroVector<double>(na_, curr_sourceerr);
      ZeroVector<double>(ns_, curr_stateerr);
    }
    // Setup the outputerr.
    if (type_ == NT_LSTM_SUMMARY) {
      if (dest_index.IsLast(FD_WIDTH)) {
        fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
        src_index.Decrement();
      } else {
        ZeroVector<double>(ns_, outputerr);
      }
    } else if (softmax_ == nullptr) {
      fwd_deltas.ReadTimeStep(t, outputerr);
    } else {
      softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
                                 softmax_errors_t.get(), outputerr);
    }
    if (!at_last_x)
      AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
    if (down_pos >= 0)
      AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
    // Apply the 1-d forget gates.
    if (!at_last_x) {
      const float* next_node_gf1 = node_values_[GF1].f(t + 1);
      for (int i = 0; i < ns_; ++i) {
        curr_stateerr[i] *= next_node_gf1[i];
      }
    }
    if (Is2D() && t + 1 < width) {
      for (int i = 0; i < ns_; ++i) {
        if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
      }
      if (down_pos >= 0) {
        const float* right_node_gfs = node_values_[GFS].f(down_pos);
        const double* right_stateerr = stateerr[mod_t];
        for (int i = 0; i < ns_; ++i) {
          if (which_fg_[down_pos][i] == 2) {
            curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
          }
        }
      }
    }
    state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
                                    curr_stateerr);
    // Clip stateerr_ to a sane range.
    ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
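    // curr_stateerr now holds dE/ds_t: the recurrent term propagated through
    // the following step's forget gate(s) above, plus the output path
    //   outputerr (*) o_t (*) h'(s_t)
    // added elementwise by FuncMultiply3Add.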
#if DEBUG_DETAIL > 1
    if (t + 10 > width) {
      tprintf("t=%d, stateerr=", t);
      for (int i = 0; i < ns_; ++i)
        tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
                curr_sourceerr[ni_ + nf_ + i]);
      tprintf("\n");
    }
#endif
    // Matrix multiply to get the source errors.
    PARALLEL_IF_OPENMP(GFS)

    // Cell inputs.
    node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
                                           curr_stateerr, gate_errors[CI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
    gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
    gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);

    SECTION_IF_OPENMP
    // Input Gates.
    node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
                                           curr_stateerr, gate_errors[GI]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
    gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
    gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);

    SECTION_IF_OPENMP
    // 1-D forget Gates.
    if (t > 0) {
      node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
                                              gate_errors[GF1]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
      gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
                                         sourceerr_temps[GF1]);
    } else {
      memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
      memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
    }
    gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);

    // 2-D forget Gates.
    if (up_pos >= 0) {
      node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
                                              gate_errors[GFS]);
      ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
      gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
                                         sourceerr_temps[GFS]);
    } else {
      memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
      memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
    }
    if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);

    SECTION_IF_OPENMP
    // Output gates.
    state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
                                         gate_errors[GO]);
    ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
    gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
    gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
    END_PARALLEL_IF_OPENMP

    SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
               sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
               curr_sourceerr);
    back_deltas->WriteTimeStep(t, curr_sourceerr);
    // Save states for use by the 2nd dimension only if needed.
    if (Is2D()) {
      CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
      CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
    }
  } while (dest_index.Decrement());
#if DEBUG_DETAIL > 2
  for (int w = 0; w < WT_COUNT; ++w) {
    tprintf("%s gate errors[%d]\n", name_.string(), w);
    gate_errors_t[w].get()->PrintUnTransposed(10);
  }
#endif
  // Transposed source_ used to speed-up SumOuter.
  NetworkScratch::GradientStore source_t, state_t;
  source_t.Init(na_, width, scratch);
  source_.Transpose(source_t.get());
  state_t.Init(ns_, width, scratch);
  state_.Transpose(state_t.get());
#ifdef _OPENMP
#pragma omp parallel for num_threads(GFS) if (!Is2D())
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
  }
  if (softmax_ != nullptr) {
    softmax_->FinishBackward(*softmax_errors_t);
  }
  return needs_to_backprop_;
}

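// The SumOuterTransposed calls above accumulate the weight gradients
//   dW[g] = sum over t of gate_errors[g](t) (outer) source(t),
// using the per-timestep error vectors stored transposed during the loop so
// that the whole accumulation runs as one matrix product per gate rather
// than width-many rank-1 updates.
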
// Updates the weights using the given learning rate, momentum and adam_beta.
// num_samples is used in the adam computation iff use_adam_ is true.
void LSTM::Update(float learning_rate, float momentum, float adam_beta,
                  int num_samples) {
#if DEBUG_DETAIL > 3
  PrintW();
#endif
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
  }
  if (softmax_ != nullptr) {
    softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
  }
#if DEBUG_DETAIL > 3
  PrintDW();
#endif
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void LSTM::CountAlternators(const Network& other, double* same,
                            double* changed) const {
  ASSERT_HOST(other.type() == type_);
  const LSTM* lstm = static_cast<const LSTM*>(&other);
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
  }
  if (softmax_ != nullptr) {
    softmax_->CountAlternators(*lstm->softmax_, same, changed);
  }
}

// Prints the weights for debug purposes.
void LSTM::PrintW() {
  tprintf("Weight state:%s\n", name_.string());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
    tprintf("\n");
  }
}

// Prints the weight deltas for debug purposes.
void LSTM::PrintDW() {
  tprintf("Delta state:%s\n", name_.string());
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;
    tprintf("Gate %d, inputs\n", w);
    for (int i = 0; i < ni_; ++i) {
      tprintf("Row %d:", i);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, outputs\n", w);
    for (int i = ni_; i < ni_ + ns_; ++i) {
      tprintf("Row %d:", i - ni_);
      for (int s = 0; s < ns_; ++s)
        tprintf(" %g", gate_weights_[w].GetDW(s, i));
      tprintf("\n");
    }
    tprintf("Gate %d, bias\n", w);
    for (int s = 0; s < ns_; ++s)
      tprintf(" %g", gate_weights_[w].GetDW(s, na_));
    tprintf("\n");
  }
}

// Resizes forward data to cope with an input image of the given width.
void LSTM::ResizeForward(const NetworkIO& input) {
  int rounded_inputs = gate_weights_[CI].RoundInputs(na_);
  source_.Resize(input, rounded_inputs);
  which_fg_.ResizeNoInit(input.Width(), ns_);
  if (IsTraining()) {
    state_.ResizeFloat(input, ns_);
    for (int w = 0; w < WT_COUNT; ++w) {
      if (w == GFS && !Is2D()) continue;
      node_values_[w].ResizeFloat(input, ns_);
    }
  }
}

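// Note: RoundInputs pads the stored feature count of source_ (presumably to
// the vector width of the integer SIMD matrix code) so that MatrixDotVector
// can read whole aligned chunks; the padding features carry no information.
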
} // namespace tesseract.