tesseract  4.0.0-1-g2a2b
tesseract::LSTM Class Reference

#include <lstm.h>

Inheritance diagram for tesseract::LSTM:
tesseract::Network

Public Types

enum  WeightType {
  CI, GI, GF1, GO,
  GFS, WT_COUNT
}
 

Public Member Functions

 LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
 
virtual ~LSTM ()
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
STRING spec () const override
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, double *same, double *changed) const override
 
void PrintW ()
 
void PrintDW ()
 
bool Is2D () const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 28 of file lstm.h.

Member Enumeration Documentation

◆ WeightType

Enumerator
CI  Cell Inputs.
GI  Gate at the input.
GF1  Forget gate at the memory (1-d or looking back 1 timestep).
GO  Gate at the output.
GFS  Forget gate at the memory, looking back in the other dimension.
WT_COUNT  Number of WeightTypes.

Definition at line 33 of file lstm.h.

33  {
34  CI, // Cell Inputs.
35  GI, // Gate at the input.
36  GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
37  GO, // Gate at the output.
38  GFS, // Forget gate at the memory, looking back in the other dimension.
39 
40  WT_COUNT // Number of WeightTypes.
41  };
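
Note that GFS exists only for genuinely two-dimensional LSTMs; every loop over the gate weights in this class therefore skips it in 1-D mode. A minimal sketch of the recurring idiom (the loop body is a placeholder):

  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;  // The 2nd forget gate exists only in 2-D mode.
    // ... operate on gate_weights_[w] ...
  }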

Constructor & Destructor Documentation

◆ LSTM()

tesseract::LSTM::LSTM ( const STRING &  name,
int  num_inputs,
int  num_states,
int  num_outputs,
bool  two_dimensional,
NetworkType  type 
)

Definition at line 99 of file lstm.cpp.

101  : Network(type, name, ni, no),
102  na_(ni + ns),
103  ns_(ns),
104  nf_(0),
105  is_2d_(two_dimensional),
106  softmax_(nullptr),
107  input_width_(0) {
108  if (two_dimensional) na_ += ns_;
109  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
110  nf_ = 0;
111  // networkbuilder ensures this is always true.
112  ASSERT_HOST(no == ns);
113  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
114  nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
115  softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
116  } else {
117  tprintf("%d is invalid type of LSTM!\n", type);
118  ASSERT_HOST(false);
119  }
120  na_ += nf_;
121 }
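
A hedged construction sketch (the sizes and seed below are purely illustrative; TRand and its set_seed method come from helpers.h). For plain NT_LSTM and NT_LSTM_SUMMARY the constructor asserts num_outputs == num_states, while the softmax variants instead attach an internal FullyConnected layer:

  // Sketch: a 1-D LSTM with 20 inputs and 64 states/outputs.
  tesseract::TRand randomizer;
  randomizer.set_seed(0x12345678);
  tesseract::LSTM* lstm = new tesseract::LSTM(
      "lstm", /*num_inputs=*/20, /*num_states=*/64, /*num_outputs=*/64,
      /*two_dimensional=*/false, tesseract::NT_LSTM);
  lstm->InitWeights(/*range=*/0.1f, &randomizer);  // Randomizes all gate weights.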

◆ ~LSTM()

tesseract::LSTM::~LSTM ( )
virtual

Definition at line 123 of file lstm.cpp.

123 { delete softmax_; }

Member Function Documentation

◆ Backward()

bool tesseract::LSTM::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 441 of file lstm.cpp.

443  {
444  if (debug) DisplayBackward(fwd_deltas);
445  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
446  // ======Scratch space.======
447  // Output errors from deltas with recurrence from sourceerr.
448  NetworkScratch::FloatVec outputerr;
449  outputerr.Init(ns_, scratch);
450  // Recurrent error in the state/source.
451  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
452  curr_stateerr.Init(ns_, scratch);
453  curr_sourceerr.Init(na_, scratch);
454  ZeroVector<double>(ns_, curr_stateerr);
455  ZeroVector<double>(na_, curr_sourceerr);
456  // Errors in the gates.
457  NetworkScratch::FloatVec gate_errors[WT_COUNT];
458  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
459  // Rotating buffers of width buf_width allow storage of the recurrent time-
460  // steps used only for true 2-D. Stores one full strip of the major direction.
461  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
462  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
463  if (Is2D()) {
464  stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
465  sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
466  for (int t = 0; t < buf_width; ++t) {
467  stateerr[t].Init(ns_, scratch);
468  sourceerr[t].Init(na_, scratch);
469  ZeroVector<double>(ns_, stateerr[t]);
470  ZeroVector<double>(na_, sourceerr[t]);
471  }
472  }
473  // Parallel-generated sourceerr from each of the gates.
474  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
475  for (int w = 0; w < WT_COUNT; ++w)
476  sourceerr_temps[w].Init(na_, scratch);
477  int width = input_width_;
478  // Transposed gate errors stored over all timesteps for sum outer.
479  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
480  for (int w = 0; w < WT_COUNT; ++w) {
481  gate_errors_t[w].Init(ns_, width, scratch);
482  }
483  // Used only if softmax_ != nullptr.
484  NetworkScratch::FloatVec softmax_errors;
485  NetworkScratch::GradientStore softmax_errors_t;
486  if (softmax_ != nullptr) {
487  softmax_errors.Init(no_, scratch);
488  softmax_errors_t.Init(no_, width, scratch);
489  }
490  double state_clip = Is2D() ? 9.0 : 4.0;
491 #if DEBUG_DETAIL > 1
492  tprintf("fwd_deltas:%s\n", name_.string());
493  fwd_deltas.Print(10);
494 #endif
495  StrideMap::Index dest_index(input_map_);
496  dest_index.InitToLast();
497  // Used only by NT_LSTM_SUMMARY.
498  StrideMap::Index src_index(fwd_deltas.stride_map());
499  src_index.InitToLast();
500  do {
501  int t = dest_index.t();
502  bool at_last_x = dest_index.IsLast(FD_WIDTH);
503  // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
504  // valid if >= 0, which is true if 2d and not on the top/bottom.
505  int up_pos = -1;
506  int down_pos = -1;
507  if (Is2D()) {
508  if (dest_index.index(FD_HEIGHT) > 0) {
509  StrideMap::Index up_index(dest_index);
510  if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
511  }
512  if (!dest_index.IsLast(FD_HEIGHT)) {
513  StrideMap::Index down_index(dest_index);
514  if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
515  }
516  }
517  // Index of the 2-D revolving buffers (sourceerr, stateerr).
518  int mod_t = Modulo(t, buf_width); // Current timestep.
519  // Zero the state in the major direction only at the end of every row.
520  if (at_last_x) {
521  ZeroVector<double>(na_, curr_sourceerr);
522  ZeroVector<double>(ns_, curr_stateerr);
523  }
524  // Setup the outputerr.
525  if (type_ == NT_LSTM_SUMMARY) {
526  if (dest_index.IsLast(FD_WIDTH)) {
527  fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
528  src_index.Decrement();
529  } else {
530  ZeroVector<double>(ns_, outputerr);
531  }
532  } else if (softmax_ == nullptr) {
533  fwd_deltas.ReadTimeStep(t, outputerr);
534  } else {
535  softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
536  softmax_errors_t.get(), outputerr);
537  }
538  if (!at_last_x)
539  AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
540  if (down_pos >= 0)
541  AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
542  // Apply the 1-d forget gates.
543  if (!at_last_x) {
544  const float* next_node_gf1 = node_values_[GF1].f(t + 1);
545  for (int i = 0; i < ns_; ++i) {
546  curr_stateerr[i] *= next_node_gf1[i];
547  }
548  }
549  if (Is2D() && t + 1 < width) {
550  for (int i = 0; i < ns_; ++i) {
551  if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
552  }
553  if (down_pos >= 0) {
554  const float* right_node_gfs = node_values_[GFS].f(down_pos);
555  const double* right_stateerr = stateerr[mod_t];
556  for (int i = 0; i < ns_; ++i) {
557  if (which_fg_[down_pos][i] == 2) {
558  curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
559  }
560  }
561  }
562  }
563  state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
564  curr_stateerr);
565  // Clip stateerr_ to a sane range.
566  ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
567 #if DEBUG_DETAIL > 1
568  if (t + 10 > width) {
569  tprintf("t=%d, stateerr=", t);
570  for (int i = 0; i < ns_; ++i)
571  tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
572  curr_sourceerr[ni_ + nf_ + i]);
573  tprintf("\n");
574  }
575 #endif
576  // Matrix multiply to get the source errors.
577  PARALLEL_IF_OPENMP(GFS)
578 
579  // Cell inputs.
580  node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
581  curr_stateerr, gate_errors[CI]);
582  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
583  gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
584  gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
585 
586  SECTION_IF_OPENMP
587  // Input Gates.
588  node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
589  curr_stateerr, gate_errors[GI]);
590  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
591  gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
592  gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
593 
594  SECTION_IF_OPENMP
595  // 1-D forget Gates.
596  if (t > 0) {
597  node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
598  gate_errors[GF1]);
599  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
600  gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
601  sourceerr_temps[GF1]);
602  } else {
603  memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
604  memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
605  }
606  gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
607 
608  // 2-D forget Gates.
609  if (up_pos >= 0) {
610  node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
611  gate_errors[GFS]);
612  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
613  gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
614  sourceerr_temps[GFS]);
615  } else {
616  memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
617  memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
618  }
619  if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
620 
621  SECTION_IF_OPENMP
622  // Output gates.
623  state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
624  gate_errors[GO]);
625  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
626  gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
627  gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
628  END_PARALLEL_IF_OPENMP
629 
630  SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
631  sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
632  curr_sourceerr);
633  back_deltas->WriteTimeStep(t, curr_sourceerr);
634  // Save states for use by the 2nd dimension only if needed.
635  if (Is2D()) {
636  CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
637  CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
638  }
639  } while (dest_index.Decrement());
640 #if DEBUG_DETAIL > 2
641  for (int w = 0; w < WT_COUNT; ++w) {
642  tprintf("%s gate errors[%d]\n", name_.string(), w);
643  gate_errors_t[w].get()->PrintUnTransposed(10);
644  }
645 #endif
646  // Transposed source_ used to speed-up SumOuter.
647  NetworkScratch::GradientStore source_t, state_t;
648  source_t.Init(na_, width, scratch);
649  source_.Transpose(source_t.get());
650  state_t.Init(ns_, width, scratch);
651  state_.Transpose(state_t.get());
652 #ifdef _OPENMP
653 #pragma omp parallel for num_threads(GFS) if (!Is2D())
654 #endif
655  for (int w = 0; w < WT_COUNT; ++w) {
656  if (w == GFS && !Is2D()) continue;
657  gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
658  }
659  if (softmax_ != nullptr) {
660  softmax_->FinishBackward(*softmax_errors_t);
661  }
662  return needs_to_backprop_;
663 }
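
Stripped of the 2-D coupling and the OpenMP sections, the state-error recurrence the loop above implements is, per state unit i (a reading of the code, not a separate derivation; G', F', H' are the derivatives of the input, gate, and output nonlinearities):

  // Walking t backwards:
  //   stateerr[i] = stateerr[i] * gf1[t + 1][i]                // back through the 1-D forget gate
  //               + H'(state[t][i]) * go[t][i] * outputerr[i];
  //   stateerr[i] = clip(stateerr[i], -state_clip, state_clip);
  // The gate errors follow the same pattern, e.g. for the cell input:
  //   gate_errors[CI][i] = G'(ci[t][i]) * gi[t][i] * stateerr[i];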

◆ ConvertToInt()

void tesseract::LSTM::ConvertToInt ( )
override virtual

Reimplemented from tesseract::Network.

Definition at line 183 of file lstm.cpp.

183  {
184  for (int w = 0; w < WT_COUNT; ++w) {
185  if (w == GFS && !Is2D()) continue;
186  gate_weights_[w].ConvertToInt();
187  }
188  if (softmax_ != nullptr) {
189  softmax_->ConvertToInt();
190  }
191 }

◆ CountAlternators()

void tesseract::LSTM::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 687 of file lstm.cpp.

688  {
689  ASSERT_HOST(other.type() == type_);
690  const LSTM* lstm = static_cast<const LSTM*>(&other);
691  for (int w = 0; w < WT_COUNT; ++w) {
692  if (w == GFS && !Is2D()) continue;
693  gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
694  }
695  if (softmax_ != nullptr) {
696  softmax_->CountAlternators(*lstm->softmax_, same, changed);
697  }
698 }

◆ DebugWeights()

void tesseract::LSTM::DebugWeights ( )
override virtual

Reimplemented from tesseract::Network.

Definition at line 194 of file lstm.cpp.

194  {
195  for (int w = 0; w < WT_COUNT; ++w) {
196  if (w == GFS && !Is2D()) continue;
197  STRING msg = name_;
198  msg.add_str_int(" Gate weights ", w);
199  gate_weights_[w].Debug2D(msg.string());
200  }
201  if (softmax_ != nullptr) {
202  softmax_->DebugWeights();
203  }
204 }

◆ DeSerialize()

bool tesseract::LSTM::DeSerialize ( TFile *  fp)
override virtual

Reimplemented from tesseract::Network.

Definition at line 220 of file lstm.cpp.

220  {
221  if (!fp->DeSerialize(&na_)) return false;
222  if (type_ == NT_LSTM_SOFTMAX) {
223  nf_ = no_;
224  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
225  nf_ = ceil_log2(no_);
226  } else {
227  nf_ = 0;
228  }
229  is_2d_ = false;
230  for (int w = 0; w < WT_COUNT; ++w) {
231  if (w == GFS && !Is2D()) continue;
232  if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
233  if (w == CI) {
234  ns_ = gate_weights_[CI].NumOutputs();
235  is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
236  }
237  }
238  delete softmax_;
239  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
240  softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
241  if (softmax_ == nullptr) return false;
242  } else {
243  softmax_ = nullptr;
244  }
245  return true;
246 }
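
Note that ns_ and is_2d_ are not serialized directly: ns_ is recovered from the output count of the CI weight matrix, and 2-D mode is inferred from the size accounting of the concatenated input vector na_:

  // na_ = ni_ + nf_ + ns_        : 1-D (one recurrent output block)
  // na_ = ni_ + nf_ + 2 * ns_    : 2-D (extra block for the other dimension)
  // hence the test above: is_2d_ = (na_ - nf_ == ni_ + 2 * ns_);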

◆ Forward()

void tesseract::LSTM::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 250 of file lstm.cpp.

252  {
253  input_map_ = input.stride_map();
254  input_width_ = input.Width();
255  if (softmax_ != nullptr)
256  output->ResizeFloat(input, no_);
257  else if (type_ == NT_LSTM_SUMMARY)
258  output->ResizeXTo1(input, no_);
259  else
260  output->Resize(input, no_);
261  ResizeForward(input);
262  // Temporary storage of forward computation for each gate.
263  NetworkScratch::FloatVec temp_lines[WT_COUNT];
264  for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
265  // Single timestep buffers for the current/recurrent output and state.
266  NetworkScratch::FloatVec curr_state, curr_output;
267  curr_state.Init(ns_, scratch);
268  ZeroVector<double>(ns_, curr_state);
269  curr_output.Init(ns_, scratch);
270  ZeroVector<double>(ns_, curr_output);
271  // Rotating buffers of width buf_width allow storage of the state and output
272  // for the other dimension, used only when working in true 2D mode. The width
273  // is enough to hold an entire strip of the major direction.
274  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
275  GenericVector<NetworkScratch::FloatVec> states, outputs;
276  if (Is2D()) {
277  states.init_to_size(buf_width, NetworkScratch::FloatVec());
278  outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
279  for (int i = 0; i < buf_width; ++i) {
280  states[i].Init(ns_, scratch);
281  ZeroVector<double>(ns_, states[i]);
282  outputs[i].Init(ns_, scratch);
283  ZeroVector<double>(ns_, outputs[i]);
284  }
285  }
286  // Used only if a softmax LSTM.
287  NetworkScratch::FloatVec softmax_output;
288  NetworkScratch::IO int_output;
289  if (softmax_ != nullptr) {
290  softmax_output.Init(no_, scratch);
291  ZeroVector<double>(no_, softmax_output);
292  int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
293  if (input.int_mode())
294  int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
295  softmax_->SetupForward(input, nullptr);
296  }
297  NetworkScratch::FloatVec curr_input;
298  curr_input.Init(na_, scratch);
299  StrideMap::Index src_index(input_map_);
300  // Used only by NT_LSTM_SUMMARY.
301  StrideMap::Index dest_index(output->stride_map());
302  do {
303  int t = src_index.t();
304  // True if there is a valid old state for the 2nd dimension.
305  bool valid_2d = Is2D();
306  if (valid_2d) {
307  StrideMap::Index dim_index(src_index);
308  if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
309  }
310  // Index of the 2-D revolving buffers (outputs, states).
311  int mod_t = Modulo(t, buf_width); // Current timestep.
312  // Setup the padded input in source.
313  source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
314  if (softmax_ != nullptr) {
315  source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
316  }
317  source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
318  if (Is2D())
319  source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
320  if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
321  // Matrix multiply the inputs with the source.
322  PARALLEL_IF_OPENMP(GFS)
323  // It looks inefficient to create the threads on each t iteration, but the
324  // alternative of putting the parallel outside the t loop, a single around
325  // the t-loop and then tasks in place of the sections is a *lot* slower.
326  // Cell inputs.
327  if (source_.int_mode())
328  gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
329  else
330  gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
331  FuncInplace<GFunc>(ns_, temp_lines[CI]);
332 
333  SECTION_IF_OPENMP
334  // Input Gates.
335  if (source_.int_mode())
336  gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
337  else
338  gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
339  FuncInplace<FFunc>(ns_, temp_lines[GI]);
340 
341  SECTION_IF_OPENMP
342  // 1-D forget gates.
343  if (source_.int_mode())
344  gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
345  else
346  gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
347  FuncInplace<FFunc>(ns_, temp_lines[GF1]);
348 
349  // 2-D forget gates.
350  if (Is2D()) {
351  if (source_.int_mode())
352  gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
353  else
354  gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
355  FuncInplace<FFunc>(ns_, temp_lines[GFS]);
356  }
357 
358  SECTION_IF_OPENMP
359  // Output gates.
360  if (source_.int_mode())
361  gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
362  else
363  gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
364  FuncInplace<FFunc>(ns_, temp_lines[GO]);
365  END_PARALLEL_IF_OPENMP
366 
367  // Apply forget gate to state.
368  MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
369  if (Is2D()) {
370  // Max-pool the forget gates (in 2-d) instead of blindly adding.
371  int8_t* which_fg_col = which_fg_[t];
372  memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
373  if (valid_2d) {
374  const double* stepped_state = states[mod_t];
375  for (int i = 0; i < ns_; ++i) {
376  if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
377  curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
378  which_fg_col[i] = 2;
379  }
380  }
381  }
382  }
383  MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
384  // Clip curr_state to a sane range.
385  ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
386  if (IsTraining()) {
387  // Save the gate node values.
388  node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
389  node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
390  node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
391  node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
392  if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
393  }
394  FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
395  if (IsTraining()) state_.WriteTimeStep(t, curr_state);
396  if (softmax_ != nullptr) {
397  if (input.int_mode()) {
398  int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
399  softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
400  } else {
401  softmax_->ForwardTimeStep(curr_output, t, softmax_output);
402  }
403  output->WriteTimeStep(t, softmax_output);
404  if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
405  CodeInBinary(no_, nf_, softmax_output);
406  }
407  } else if (type_ == NT_LSTM_SUMMARY) {
408  // Output only at the end of a row.
409  if (src_index.IsLast(FD_WIDTH)) {
410  output->WriteTimeStep(dest_index.t(), curr_output);
411  dest_index.Increment();
412  }
413  } else {
414  output->WriteTimeStep(t, curr_output);
415  }
416  // Save states for use by the 2nd dimension only if needed.
417  if (Is2D()) {
418  CopyVector(ns_, curr_state, states[mod_t]);
419  CopyVector(ns_, curr_output, outputs[mod_t]);
420  }
421  // Always zero the states at the end of every row, but only for the major
422  // direction. The 2-D state remains intact.
423  if (src_index.IsLast(FD_WIDTH)) {
424  ZeroVector<double>(ns_, curr_state);
425  ZeroVector<double>(ns_, curr_output);
426  }
427  } while (src_index.Increment());
428 #if DEBUG_DETAIL > 0
429  tprintf("Source:%s\n", name_.string());
430  source_.Print(10);
431  tprintf("State:%s\n", name_.string());
432  state_.Print(10);
433  tprintf("Output:%s\n", name_.string());
434  output->Print(10);
435 #endif
436  if (debug) DisplayForward(*output);
437 }
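
Ignoring the 2-D forget-gate max-pooling, the per-unit update in the loop above reduces to the standard LSTM cell equations (a reading of the code; ci, gi, gf1, go denote the already-squashed gate activations held in temp_lines):

  // For each state unit i at timestep t:
  //   state[i]  = gf1[i] * state[i]           // forget gate on the previous state
  //             + ci[i] * gi[i];              // gated cell input
  //   state[i]  = clip(state[i], -kStateClip, kStateClip);
  //   output[i] = H(state[i]) * go[i];        // gated output
  // In 2-D mode, whichever of the two forget gates (GF1 from the previous step,
  // GFS from the other dimension) is larger wins per unit and selects which
  // previous state is carried forward.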

◆ InitWeights()

int tesseract::LSTM::InitWeights ( float  range,
TRand *  randomizer 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 158 of file lstm.cpp.

158  {
159  Network::SetRandomizer(randomizer);
160  num_weights_ = 0;
161  for (int w = 0; w < WT_COUNT; ++w) {
162  if (w == GFS && !Is2D()) continue;
163  num_weights_ += gate_weights_[w].InitWeightsFloat(
164  ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
165  }
166  if (softmax_ != nullptr) {
167  num_weights_ += softmax_->InitWeights(range, randomizer);
168  }
169  return num_weights_;
170 }

◆ Is2D()

bool tesseract::LSTM::Is2D ( ) const
inline

Definition at line 119 of file lstm.h.

119  {
120  return is_2d_;
121  }

◆ OutputShape()

StaticShape tesseract::LSTM::OutputShape ( const StaticShape &  input_shape) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 127 of file lstm.cpp.

127  {
128  StaticShape result = input_shape;
129  result.set_depth(no_);
130  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
131  if (softmax_ != nullptr) return softmax_->OutputShape(result);
132  return result;
133 }
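For example, an input shape keeps its height and width but its depth becomes no_; NT_LSTM_SUMMARY additionally collapses the width to 1, and the softmax variants delegate the final shape to the internal FullyConnected layer.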

◆ PrintDW()

void tesseract::LSTM::PrintDW ( )

Definition at line 727 of file lstm.cpp.

727  {
728  tprintf("Delta state:%s\n", name_.string());
729  for (int w = 0; w < WT_COUNT; ++w) {
730  if (w == GFS && !Is2D()) continue;
731  tprintf("Gate %d, inputs\n", w);
732  for (int i = 0; i < ni_; ++i) {
733  tprintf("Row %d:", i);
734  for (int s = 0; s < ns_; ++s)
735  tprintf(" %g", gate_weights_[w].GetDW(s, i));
736  tprintf("\n");
737  }
738  tprintf("Gate %d, outputs\n", w);
739  for (int i = ni_; i < ni_ + ns_; ++i) {
740  tprintf("Row %d:", i - ni_);
741  for (int s = 0; s < ns_; ++s)
742  tprintf(" %g", gate_weights_[w].GetDW(s, i));
743  tprintf("\n");
744  }
745  tprintf("Gate %d, bias\n", w);
746  for (int s = 0; s < ns_; ++s)
747  tprintf(" %g", gate_weights_[w].GetDW(s, na_));
748  tprintf("\n");
749  }
750 }

◆ PrintW()

void tesseract::LSTM::PrintW ( )

Definition at line 701 of file lstm.cpp.

701  {
702  tprintf("Weight state:%s\n", name_.string());
703  for (int w = 0; w < WT_COUNT; ++w) {
704  if (w == GFS && !Is2D()) continue;
705  tprintf("Gate %d, inputs\n", w);
706  for (int i = 0; i < ni_; ++i) {
707  tprintf("Row %d:", i);
708  for (int s = 0; s < ns_; ++s)
709  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
710  tprintf("\n");
711  }
712  tprintf("Gate %d, outputs\n", w);
713  for (int i = ni_; i < ni_ + ns_; ++i) {
714  tprintf("Row %d:", i - ni_);
715  for (int s = 0; s < ns_; ++s)
716  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
717  tprintf("\n");
718  }
719  tprintf("Gate %d, bias\n", w);
720  for (int s = 0; s < ns_; ++s)
721  tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
722  tprintf("\n");
723  }
724 }

◆ RemapOutputs()

int tesseract::LSTM::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 174 of file lstm.cpp.

174  {
175  if (softmax_ != nullptr) {
176  num_weights_ -= softmax_->num_weights();
177  num_weights_ += softmax_->RemapOutputs(old_no, code_map);
178  }
179  return num_weights_;
180 }

◆ Serialize()

bool tesseract::LSTM::Serialize ( TFile *  fp) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 207 of file lstm.cpp.

207  {
208  if (!Network::Serialize(fp)) return false;
209  if (!fp->Serialize(&na_)) return false;
210  for (int w = 0; w < WT_COUNT; ++w) {
211  if (w == GFS && !Is2D()) continue;
212  if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
213  }
214  if (softmax_ != nullptr && !softmax_->Serialize(fp)) return false;
215  return true;
216 }

◆ SetEnableTraining()

void tesseract::LSTM::SetEnableTraining ( TrainingState  state)
override virtual

Reimplemented from tesseract::Network.

Definition at line 137 of file lstm.cpp.

137  {
138  if (state == TS_RE_ENABLE) {
139  // Enable only from temp disabled.
140  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
141  } else if (state == TS_TEMP_DISABLE) {
142  // Temp disable only from enabled.
143  if (training_ == TS_ENABLED) training_ = state;
144  } else {
145  if (state == TS_ENABLED && training_ != TS_ENABLED) {
146  for (int w = 0; w < WT_COUNT; ++w) {
147  if (w == GFS && !Is2D()) continue;
148  gate_weights_[w].InitBackward();
149  }
150  }
151  training_ = state;
152  }
153  if (softmax_ != nullptr) softmax_->SetEnableTraining(state);
154 }

◆ spec()

STRING tesseract::LSTM::spec ( ) const
inline override virtual

Reimplemented from tesseract::Network.

Definition at line 58 of file lstm.h.

58  {
59  STRING spec;
60  if (type_ == NT_LSTM)
61  spec.add_str_int("Lfx", ns_);
62  else if (type_ == NT_LSTM_SUMMARY)
63  spec.add_str_int("Lfxs", ns_);
64  else if (type_ == NT_LSTM_SOFTMAX)
65  spec.add_str_int("LS", ns_);
66  else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
67  spec.add_str_int("LE", ns_);
68  if (softmax_ != nullptr) spec += softmax_->spec();
69  return spec;
70  }
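For example, a plain 1-D NT_LSTM with 256 states reports "Lfx256" and an NT_LSTM_SUMMARY reports "Lfxs256"; the softmax variants ("LS"/"LE") additionally append the spec of the internal softmax layer.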

◆ Update()

void tesseract::LSTM::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 667 of file lstm.cpp.

668  {
669 #if DEBUG_DETAIL > 3
670  PrintW();
671 #endif
672  for (int w = 0; w < WT_COUNT; ++w) {
673  if (w == GFS && !Is2D()) continue;
674  gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
675  }
676  if (softmax_ != nullptr) {
677  softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
678  }
679 #if DEBUG_DETAIL > 3
680  PrintDW();
681 #endif
682 }

The documentation for this class was generated from the following files:

lstm.h
lstm.cpp