tesseract  5.0.0-alpha-619-ge9db
tesseract::LSTM Class Reference

#include <lstm.h>

Inheritance diagram for tesseract::LSTM:
tesseract::Network

Public Types

enum  WeightType {
  CI, GI, GF1, GO,
  GFS, WT_COUNT
}
 

Public Member Functions

 LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
 
 ~LSTM () override
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
STRING spec () const override
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, double *same, double *changed) const override
 
void PrintW ()
 
void PrintDW ()
 
bool Is2D () const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Detailed Description

Definition at line 28 of file lstm.h.

Member Enumeration Documentation

◆ WeightType

Enumerator
CI 
GI 
GF1 
GO 
GFS 
WT_COUNT 

Definition at line 33 of file lstm.h.

33  enum WeightType {
34  CI, // Cell Inputs.
35  GI, // Gate at the input.
36  GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
37  GO, // Gate at the output.
38  GFS, // Forget gate at the memory, looking back in the other dimension.
39 
40  WT_COUNT // Number of WeightTypes.
41  };
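
The enum values double as indices into the per-gate weight matrices and node buffers used throughout the class. As a minimal illustration of the iteration pattern that recurs in InitWeights(), ConvertToInt(), Update() and the serialization code (see the listings below), every loop over WT_COUNT skips GFS unless the network is truly two-dimensional:

  // Sketch of the recurring gate-visiting pattern; gate_weights_ is the
  // per-gate WeightMatrix array visible in the listings below.
  for (int w = 0; w < WT_COUNT; ++w) {
    if (w == GFS && !Is2D()) continue;  // GFS exists only in 2-D mode.
    gate_weights_[w].ConvertToInt();
  }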

Constructor & Destructor Documentation

◆ LSTM()

tesseract::LSTM::LSTM ( const STRING &  name,
int  num_inputs,
int  num_states,
int  num_outputs,
bool  two_dimensional,
NetworkType  type 
)

Definition at line 98 of file lstm.cpp.

98 LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
99  NetworkType type)
100  : Network(type, name, ni, no),
101  na_(ni + ns),
102  ns_(ns),
103  nf_(0),
104  is_2d_(two_dimensional),
105  softmax_(nullptr),
106  input_width_(0) {
107  if (two_dimensional) na_ += ns_;
108  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
109  nf_ = 0;
110  // networkbuilder ensures this is always true.
111  ASSERT_HOST(no == ns);
112  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
113  nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : ceil_log2(no_);
114  softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
115  } else {
116  tprintf("%d is invalid type of LSTM!\n", type);
117  ASSERT_HOST(false);
118  }
119  na_ += nf_;
120 }
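
The padded input width na_ thus ends up as ni + ns (+ ns again when two_dimensional) + nf_. As a purely hypothetical example, a 2-D NT_LSTM_SOFTMAX_ENCODED layer with ni = 64, ns = 96 and no = 111 gets nf_ = ceil_log2(111) = 7, and therefore na_ = 64 + 96 + 96 + 7 = 263.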

◆ ~LSTM()

tesseract::LSTM::~LSTM ( )
override

Definition at line 122 of file lstm.cpp.

122 LSTM::~LSTM() { delete softmax_; }

Member Function Documentation

◆ Backward()

bool tesseract::LSTM::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
override virtual

Implements tesseract::Network.

Definition at line 440 of file lstm.cpp.

440 bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
441  NetworkScratch* scratch,
442  NetworkIO* back_deltas) {
443  if (debug) DisplayBackward(fwd_deltas);
444  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
445  // ======Scratch space.======
446  // Output errors from deltas with recurrence from sourceerr.
447  NetworkScratch::FloatVec outputerr;
448  outputerr.Init(ns_, scratch);
449  // Recurrent error in the state/source.
450  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
451  curr_stateerr.Init(ns_, scratch);
452  curr_sourceerr.Init(na_, scratch);
453  ZeroVector<double>(ns_, curr_stateerr);
454  ZeroVector<double>(na_, curr_sourceerr);
455  // Errors in the gates.
456  NetworkScratch::FloatVec gate_errors[WT_COUNT];
457  for (auto & gate_error : gate_errors) gate_error.Init(ns_, scratch);
458  // Rotating buffers of width buf_width allow storage of the recurrent time-
459  // steps used only for true 2-D. Stores one full strip of the major direction.
460  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
461  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
462  if (Is2D()) {
463  stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
464  sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
465  for (int t = 0; t < buf_width; ++t) {
466  stateerr[t].Init(ns_, scratch);
467  sourceerr[t].Init(na_, scratch);
468  ZeroVector<double>(ns_, stateerr[t]);
469  ZeroVector<double>(na_, sourceerr[t]);
470  }
471  }
472  // Parallel-generated sourceerr from each of the gates.
473  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
474  for (auto & sourceerr_temp : sourceerr_temps)
475  sourceerr_temp.Init(na_, scratch);
476  int width = input_width_;
477  // Transposed gate errors stored over all timesteps for sum outer.
478  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
479  for (auto & w : gate_errors_t) {
480  w.Init(ns_, width, scratch);
481  }
482  // Used only if softmax_ != nullptr.
483  NetworkScratch::FloatVec softmax_errors;
484  NetworkScratch::GradientStore softmax_errors_t;
485  if (softmax_ != nullptr) {
486  softmax_errors.Init(no_, scratch);
487  softmax_errors_t.Init(no_, width, scratch);
488  }
489  double state_clip = Is2D() ? 9.0 : 4.0;
490 #if DEBUG_DETAIL > 1
491  tprintf("fwd_deltas:%s\n", name_.c_str());
492  fwd_deltas.Print(10);
493 #endif
494  StrideMap::Index dest_index(input_map_);
495  dest_index.InitToLast();
496  // Used only by NT_LSTM_SUMMARY.
497  StrideMap::Index src_index(fwd_deltas.stride_map());
498  src_index.InitToLast();
499  do {
500  int t = dest_index.t();
501  bool at_last_x = dest_index.IsLast(FD_WIDTH);
502  // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
503  // valid if >= 0, which is true if 2d and not on the top/bottom.
504  int up_pos = -1;
505  int down_pos = -1;
506  if (Is2D()) {
507  if (dest_index.index(FD_HEIGHT) > 0) {
508  StrideMap::Index up_index(dest_index);
509  if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
510  }
511  if (!dest_index.IsLast(FD_HEIGHT)) {
512  StrideMap::Index down_index(dest_index);
513  if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
514  }
515  }
516  // Index of the 2-D revolving buffers (sourceerr, stateerr).
517  int mod_t = Modulo(t, buf_width); // Current timestep.
518  // Zero the state in the major direction only at the end of every row.
519  if (at_last_x) {
520  ZeroVector<double>(na_, curr_sourceerr);
521  ZeroVector<double>(ns_, curr_stateerr);
522  }
523  // Setup the outputerr.
524  if (type_ == NT_LSTM_SUMMARY) {
525  if (dest_index.IsLast(FD_WIDTH)) {
526  fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
527  src_index.Decrement();
528  } else {
529  ZeroVector<double>(ns_, outputerr);
530  }
531  } else if (softmax_ == nullptr) {
532  fwd_deltas.ReadTimeStep(t, outputerr);
533  } else {
534  softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
535  softmax_errors_t.get(), outputerr);
536  }
537  if (!at_last_x)
538  AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
539  if (down_pos >= 0)
540  AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
541  // Apply the 1-d forget gates.
542  if (!at_last_x) {
543  const float* next_node_gf1 = node_values_[GF1].f(t + 1);
544  for (int i = 0; i < ns_; ++i) {
545  curr_stateerr[i] *= next_node_gf1[i];
546  }
547  }
548  if (Is2D() && t + 1 < width) {
549  for (int i = 0; i < ns_; ++i) {
550  if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
551  }
552  if (down_pos >= 0) {
553  const float* right_node_gfs = node_values_[GFS].f(down_pos);
554  const double* right_stateerr = stateerr[mod_t];
555  for (int i = 0; i < ns_; ++i) {
556  if (which_fg_[down_pos][i] == 2) {
557  curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
558  }
559  }
560  }
561  }
562  state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
563  curr_stateerr);
564  // Clip stateerr_ to a sane range.
565  ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
566 #if DEBUG_DETAIL > 1
567  if (t + 10 > width) {
568  tprintf("t=%d, stateerr=", t);
569  for (int i = 0; i < ns_; ++i)
570  tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
571  curr_sourceerr[ni_ + nf_ + i]);
572  tprintf("\n");
573  }
574 #endif
575  // Matrix multiply to get the source errors.
576  PARALLEL_IF_OPENMP(GFS)
577 
578  // Cell inputs.
579  node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
580  curr_stateerr, gate_errors[CI]);
581  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
582  gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
583  gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
584 
585  SECTION_IF_OPENMP
586  // Input Gates.
587  node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
588  curr_stateerr, gate_errors[GI]);
589  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
590  gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
591  gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
592 
593  SECTION_IF_OPENMP
594  // 1-D forget Gates.
595  if (t > 0) {
596  node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
597  gate_errors[GF1]);
598  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
599  gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
600  sourceerr_temps[GF1]);
601  } else {
602  memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
603  memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
604  }
605  gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
606 
607  // 2-D forget Gates.
608  if (up_pos >= 0) {
609  node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
610  gate_errors[GFS]);
611  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
612  gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
613  sourceerr_temps[GFS]);
614  } else {
615  memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
616  memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
617  }
618  if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
619 
620  SECTION_IF_OPENMP
621  // Output gates.
622  state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
623  gate_errors[GO]);
624  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
625  gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
626  gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
627  END_PARALLEL_IF_OPENMP
628 
629  SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
630  sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
631  curr_sourceerr);
632  back_deltas->WriteTimeStep(t, curr_sourceerr);
633  // Save states for use by the 2nd dimension only if needed.
634  if (Is2D()) {
635  CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
636  CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
637  }
638  } while (dest_index.Decrement());
639 #if DEBUG_DETAIL > 2
640  for (int w = 0; w < WT_COUNT; ++w) {
641  tprintf("%s gate errors[%d]\n", name_.c_str(), w);
642  gate_errors_t[w].get()->PrintUnTransposed(10);
643  }
644 #endif
645  // Transposed source_ used to speed-up SumOuter.
646  NetworkScratch::GradientStore source_t, state_t;
647  source_t.Init(na_, width, scratch);
648  source_.Transpose(source_t.get());
649  state_t.Init(ns_, width, scratch);
650  state_.Transpose(state_t.get());
651 #ifdef _OPENMP
652 #pragma omp parallel for num_threads(GFS) if (!Is2D())
653 #endif
654  for (int w = 0; w < WT_COUNT; ++w) {
655  if (w == GFS && !Is2D()) continue;
656  gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
657  }
658  if (softmax_ != nullptr) {
659  softmax_->FinishBackward(*softmax_errors_t);
660  }
661  return needs_to_backprop_;
662 }
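
In standard LSTM notation, the backward recurrence this listing implements is (a sketch for the 1-D case; G, F, H denote the GFunc, FFunc and HFunc nonlinearities, with derivatives evaluated from the stored node_values_, and \odot the elementwise product):

  \delta s_t = f_{t+1} \odot \delta s_{t+1} + o_t \odot H'(s_t) \odot \delta y_t   (clipped to \pm state_clip)
  \delta g_t = G'(g_t) \odot i_t \odot \delta s_t        (cell inputs, CI)
  \delta i_t = F'(i_t) \odot g_t \odot \delta s_t        (input gate, GI)
  \delta f_t = F'(f_t) \odot s_{t-1} \odot \delta s_t    (1-D forget gate, GF1)
  \delta o_t = F'(o_t) \odot H(s_t) \odot \delta y_t     (output gate, GO)

Each gate error is clipped to \pm kErrClip and back-projected through its transposed gate matrix (VectorDotMatrix) to build the source error; after the time loop, SumOuterTransposed accumulates the outer products of the transposed gate errors with the transposed source to form the weight gradients.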

◆ ConvertToInt()

void tesseract::LSTM::ConvertToInt ( )
override virtual

Reimplemented from tesseract::Network.

Definition at line 182 of file lstm.cpp.

182 void LSTM::ConvertToInt() {
183  for (int w = 0; w < WT_COUNT; ++w) {
184  if (w == GFS && !Is2D()) continue;
185  gate_weights_[w].ConvertToInt();
186  }
187  if (softmax_ != nullptr) {
188  softmax_->ConvertToInt();
189  }
190 }

◆ CountAlternators()

void tesseract::LSTM::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 686 of file lstm.cpp.

686 void LSTM::CountAlternators(const Network& other, double* same,
687  double* changed) const {
688  ASSERT_HOST(other.type() == type_);
689  const LSTM* lstm = static_cast<const LSTM*>(&other);
690  for (int w = 0; w < WT_COUNT; ++w) {
691  if (w == GFS && !Is2D()) continue;
692  gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
693  }
694  if (softmax_ != nullptr) {
695  softmax_->CountAlternators(*lstm->softmax_, same, changed);
696  }
697 }

◆ DebugWeights()

void tesseract::LSTM::DebugWeights ( )
override virtual

Implements tesseract::Network.

Definition at line 193 of file lstm.cpp.

193 void LSTM::DebugWeights() {
194  for (int w = 0; w < WT_COUNT; ++w) {
195  if (w == GFS && !Is2D()) continue;
196  STRING msg = name_;
197  msg.add_str_int(" Gate weights ", w);
198  gate_weights_[w].Debug2D(msg.c_str());
199  }
200  if (softmax_ != nullptr) {
201  softmax_->DebugWeights();
202  }
203 }

◆ DeSerialize()

bool tesseract::LSTM::DeSerialize ( TFile *  fp)
override virtual

Implements tesseract::Network.

Definition at line 219 of file lstm.cpp.

219 bool LSTM::DeSerialize(TFile* fp) {
220  if (!fp->DeSerialize(&na_)) return false;
221  if (type_ == NT_LSTM_SOFTMAX) {
222  nf_ = no_;
223  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
224  nf_ = ceil_log2(no_);
225  } else {
226  nf_ = 0;
227  }
228  is_2d_ = false;
229  for (int w = 0; w < WT_COUNT; ++w) {
230  if (w == GFS && !Is2D()) continue;
231  if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
232  if (w == CI) {
233  ns_ = gate_weights_[CI].NumOutputs();
234  is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
235  }
236  }
237  delete softmax_;
238  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
239  softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
240  if (softmax_ == nullptr) return false;
241  } else {
242  softmax_ = nullptr;
243  }
244  return true;
245 }
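
Note that is_2d_ is not serialized explicitly: the constructor sets na_ = ni_ + nf_ + ns_ for a 1-D network and na_ = ni_ + nf_ + 2 * ns_ for a 2-D one, so once ns_ is known from the deserialized CI matrix, the test na_ - nf_ == ni_ + 2 * ns_ recovers the flag.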

◆ Forward()

void tesseract::LSTM::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
override virtual

Implements tesseract::Network.

Definition at line 249 of file lstm.cpp.

249 void LSTM::Forward(bool debug, const NetworkIO& input,
250  const TransposedArray* input_transpose,
251  NetworkScratch* scratch, NetworkIO* output) {
252  input_map_ = input.stride_map();
253  input_width_ = input.Width();
254  if (softmax_ != nullptr)
255  output->ResizeFloat(input, no_);
256  else if (type_ == NT_LSTM_SUMMARY)
257  output->ResizeXTo1(input, no_);
258  else
259  output->Resize(input, no_);
260  ResizeForward(input);
261  // Temporary storage of forward computation for each gate.
262  NetworkScratch::FloatVec temp_lines[WT_COUNT];
263  for (auto & temp_line : temp_lines) temp_line.Init(ns_, scratch);
264  // Single timestep buffers for the current/recurrent output and state.
265  NetworkScratch::FloatVec curr_state, curr_output;
266  curr_state.Init(ns_, scratch);
267  ZeroVector<double>(ns_, curr_state);
268  curr_output.Init(ns_, scratch);
269  ZeroVector<double>(ns_, curr_output);
270  // Rotating buffers of width buf_width allow storage of the state and output
271  // for the other dimension, used only when working in true 2D mode. The width
272  // is enough to hold an entire strip of the major direction.
273  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
274  GenericVector<NetworkScratch::FloatVec> states, outputs;
275  if (Is2D()) {
276  states.init_to_size(buf_width, NetworkScratch::FloatVec());
277  outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
278  for (int i = 0; i < buf_width; ++i) {
279  states[i].Init(ns_, scratch);
280  ZeroVector<double>(ns_, states[i]);
281  outputs[i].Init(ns_, scratch);
282  ZeroVector<double>(ns_, outputs[i]);
283  }
284  }
285  // Used only if a softmax LSTM.
286  NetworkScratch::FloatVec softmax_output;
287  NetworkScratch::IO int_output;
288  if (softmax_ != nullptr) {
289  softmax_output.Init(no_, scratch);
290  ZeroVector<double>(no_, softmax_output);
291  int rounded_softmax_inputs = gate_weights_[CI].RoundInputs(ns_);
292  if (input.int_mode())
293  int_output.Resize2d(true, 1, rounded_softmax_inputs, scratch);
294  softmax_->SetupForward(input, nullptr);
295  }
296  NetworkScratch::FloatVec curr_input;
297  curr_input.Init(na_, scratch);
298  StrideMap::Index src_index(input_map_);
299  // Used only by NT_LSTM_SUMMARY.
300  StrideMap::Index dest_index(output->stride_map());
301  do {
302  int t = src_index.t();
303  // True if there is a valid old state for the 2nd dimension.
304  bool valid_2d = Is2D();
305  if (valid_2d) {
306  StrideMap::Index dim_index(src_index);
307  if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
308  }
309  // Index of the 2-D revolving buffers (outputs, states).
310  int mod_t = Modulo(t, buf_width); // Current timestep.
311  // Setup the padded input in source.
312  source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
313  if (softmax_ != nullptr) {
314  source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
315  }
316  source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
317  if (Is2D())
318  source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
319  if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
320  // Matrix multiply the inputs with the source.
321  PARALLEL_IF_OPENMP(GFS)
322  // It looks inefficient to create the threads on each t iteration, but the
323  // alternative of putting the parallel outside the t loop, a single around
324  // the t-loop and then tasks in place of the sections is a *lot* slower.
325  // Cell inputs.
326  if (source_.int_mode())
327  gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
328  else
329  gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
330  FuncInplace<GFunc>(ns_, temp_lines[CI]);
331 
332  SECTION_IF_OPENMP
333  // Input Gates.
334  if (source_.int_mode())
335  gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
336  else
337  gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
338  FuncInplace<FFunc>(ns_, temp_lines[GI]);
339 
340  SECTION_IF_OPENMP
341  // 1-D forget gates.
342  if (source_.int_mode())
343  gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
344  else
345  gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
346  FuncInplace<FFunc>(ns_, temp_lines[GF1]);
347 
348  // 2-D forget gates.
349  if (Is2D()) {
350  if (source_.int_mode())
351  gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
352  else
353  gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
354  FuncInplace<FFunc>(ns_, temp_lines[GFS]);
355  }
356 
357  SECTION_IF_OPENMP
358  // Output gates.
359  if (source_.int_mode())
360  gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
361  else
362  gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
363  FuncInplace<FFunc>(ns_, temp_lines[GO]);
364  END_PARALLEL_IF_OPENMP
365 
366  // Apply forget gate to state.
367  MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
368  if (Is2D()) {
369  // Max-pool the forget gates (in 2-d) instead of blindly adding.
370  int8_t* which_fg_col = which_fg_[t];
371  memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
372  if (valid_2d) {
373  const double* stepped_state = states[mod_t];
374  for (int i = 0; i < ns_; ++i) {
375  if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
376  curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
377  which_fg_col[i] = 2;
378  }
379  }
380  }
381  }
382  MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
383  // Clip curr_state to a sane range.
384  ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
385  if (IsTraining()) {
386  // Save the gate node values.
387  node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
388  node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
389  node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
390  node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
391  if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
392  }
393  FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
394  if (IsTraining()) state_.WriteTimeStep(t, curr_state);
395  if (softmax_ != nullptr) {
396  if (input.int_mode()) {
397  int_output->WriteTimeStepPart(0, 0, ns_, curr_output);
398  softmax_->ForwardTimeStep(int_output->i(0), t, softmax_output);
399  } else {
400  softmax_->ForwardTimeStep(curr_output, t, softmax_output);
401  }
402  output->WriteTimeStep(t, softmax_output);
403  if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
404  CodeInBinary(no_, nf_, softmax_output);
405  }
406  } else if (type_ == NT_LSTM_SUMMARY) {
407  // Output only at the end of a row.
408  if (src_index.IsLast(FD_WIDTH)) {
409  output->WriteTimeStep(dest_index.t(), curr_output);
410  dest_index.Increment();
411  }
412  } else {
413  output->WriteTimeStep(t, curr_output);
414  }
415  // Save states for use by the 2nd dimension only if needed.
416  if (Is2D()) {
417  CopyVector(ns_, curr_state, states[mod_t]);
418  CopyVector(ns_, curr_output, outputs[mod_t]);
419  }
420  // Always zero the states at the end of every row, but only for the major
421  // direction. The 2-D state remains intact.
422  if (src_index.IsLast(FD_WIDTH)) {
423  ZeroVector<double>(ns_, curr_state);
424  ZeroVector<double>(ns_, curr_output);
425  }
426  } while (src_index.Increment());
427 #if DEBUG_DETAIL > 0
428  tprintf("Source:%s\n", name_.c_str());
429  source_.Print(10);
430  tprintf("State:%s\n", name_.c_str());
431  state_.Print(10);
432  tprintf("Output:%s\n", name_.c_str());
433  output->Print(10);
434 #endif
435  if (debug) DisplayForward(*output);
436 }
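
In the same notation as the Backward() section above, the per-timestep update this listing implements is:

  g_t = G(W_{CI} x_t),   i_t = F(W_{GI} x_t),   f_t = F(W_{GF1} x_t),   o_t = F(W_{GO} x_t)
  s_t = f_t \odot s_{t-1} + g_t \odot i_t   (clipped to \pm kStateClip)
  y_t = o_t \odot H(s_t)

where x_t is the padded source vector of width na_: the ni_ image inputs, the nf_ softmax feedback (softmax variants only) and the ns_ recurrent outputs (twice in 2-D mode). In 2-D mode the forget contribution is max-pooled per cell between the GF1 and GFS gates, with the winner recorded in which_fg_ for use by Backward(), rather than summed.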

◆ InitWeights()

int tesseract::LSTM::InitWeights ( float  range,
TRand *  randomizer 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 157 of file lstm.cpp.

157 int LSTM::InitWeights(float range, TRand* randomizer) {
158  Network::SetRandomizer(randomizer);
159  num_weights_ = 0;
160  for (int w = 0; w < WT_COUNT; ++w) {
161  if (w == GFS && !Is2D()) continue;
162  num_weights_ += gate_weights_[w].InitWeightsFloat(
163  ns_, na_ + 1, TestFlag(NF_ADAM), range, randomizer);
164  }
165  if (softmax_ != nullptr) {
166  num_weights_ += softmax_->InitWeights(range, randomizer);
167  }
168  return num_weights_;
169 }
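
Each gate matrix is therefore ns_ outputs by na_ + 1 inputs; the extra input column holds the bias, which is why PrintW() and PrintDW() below read it back at index na_.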

◆ Is2D()

bool tesseract::LSTM::Is2D ( ) const
inline

Definition at line 119 of file lstm.h.

119  bool Is2D() const {
120  return is_2d_;
121  }

◆ OutputShape()

StaticShape tesseract::LSTM::OutputShape ( const StaticShape &  input_shape) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 126 of file lstm.cpp.

126 StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
127  StaticShape result = input_shape;
128  result.set_depth(no_);
129  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
130  if (softmax_ != nullptr) return softmax_->OutputShape(result);
131  return result;
132 }
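
For example, a plain NT_LSTM leaves batch, height and width unchanged and only replaces the depth with no_; NT_LSTM_SUMMARY additionally collapses the width to 1 (one output per row), and the softmax variants defer the final shape to the embedded FullyConnected layer.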

◆ PrintDW()

void tesseract::LSTM::PrintDW ( )

Definition at line 726 of file lstm.cpp.

726 void LSTM::PrintDW() {
727  tprintf("Delta state:%s\n", name_.c_str());
728  for (int w = 0; w < WT_COUNT; ++w) {
729  if (w == GFS && !Is2D()) continue;
730  tprintf("Gate %d, inputs\n", w);
731  for (int i = 0; i < ni_; ++i) {
732  tprintf("Row %d:", i);
733  for (int s = 0; s < ns_; ++s)
734  tprintf(" %g", gate_weights_[w].GetDW(s, i));
735  tprintf("\n");
736  }
737  tprintf("Gate %d, outputs\n", w);
738  for (int i = ni_; i < ni_ + ns_; ++i) {
739  tprintf("Row %d:", i - ni_);
740  for (int s = 0; s < ns_; ++s)
741  tprintf(" %g", gate_weights_[w].GetDW(s, i));
742  tprintf("\n");
743  }
744  tprintf("Gate %d, bias\n", w);
745  for (int s = 0; s < ns_; ++s)
746  tprintf(" %g", gate_weights_[w].GetDW(s, na_));
747  tprintf("\n");
748  }
749 }

◆ PrintW()

void tesseract::LSTM::PrintW ( )

Definition at line 700 of file lstm.cpp.

700 void LSTM::PrintW() {
701  tprintf("Weight state:%s\n", name_.c_str());
702  for (int w = 0; w < WT_COUNT; ++w) {
703  if (w == GFS && !Is2D()) continue;
704  tprintf("Gate %d, inputs\n", w);
705  for (int i = 0; i < ni_; ++i) {
706  tprintf("Row %d:", i);
707  for (int s = 0; s < ns_; ++s)
708  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
709  tprintf("\n");
710  }
711  tprintf("Gate %d, outputs\n", w);
712  for (int i = ni_; i < ni_ + ns_; ++i) {
713  tprintf("Row %d:", i - ni_);
714  for (int s = 0; s < ns_; ++s)
715  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
716  tprintf("\n");
717  }
718  tprintf("Gate %d, bias\n", w);
719  for (int s = 0; s < ns_; ++s)
720  tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
721  tprintf("\n");
722  }
723 }

◆ RemapOutputs()

int tesseract::LSTM::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 173 of file lstm.cpp.

173 int LSTM::RemapOutputs(int old_no, const std::vector<int>& code_map) {
174  if (softmax_ != nullptr) {
175  num_weights_ -= softmax_->num_weights();
176  num_weights_ += softmax_->RemapOutputs(old_no, code_map);
177  }
178  return num_weights_;
179 }

◆ Serialize()

bool tesseract::LSTM::Serialize ( TFile *  fp) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 206 of file lstm.cpp.

206 bool LSTM::Serialize(TFile* fp) const {
207  if (!Network::Serialize(fp)) return false;
208  if (!fp->Serialize(&na_)) return false;
209  for (int w = 0; w < WT_COUNT; ++w) {
210  if (w == GFS && !Is2D()) continue;
211  if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
212  }
213  if (softmax_ != nullptr && !softmax_->Serialize(fp)) return false;
214  return true;
215 }

◆ SetEnableTraining()

void tesseract::LSTM::SetEnableTraining ( TrainingState  state)
override virtual

Reimplemented from tesseract::Network.

Definition at line 136 of file lstm.cpp.

136 void LSTM::SetEnableTraining(TrainingState state) {
137  if (state == TS_RE_ENABLE) {
138  // Enable only from temp disabled.
139  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
140  } else if (state == TS_TEMP_DISABLE) {
141  // Temp disable only from enabled.
142  if (training_ == TS_ENABLED) training_ = state;
143  } else {
144  if (state == TS_ENABLED && training_ != TS_ENABLED) {
145  for (int w = 0; w < WT_COUNT; ++w) {
146  if (w == GFS && !Is2D()) continue;
147  gate_weights_[w].InitBackward();
148  }
149  }
150  training_ = state;
151  }
152  if (softmax_ != nullptr) softmax_->SetEnableTraining(state);
153 }

◆ spec()

STRING tesseract::LSTM::spec ( ) const
inline override virtual

Reimplemented from tesseract::Network.

Definition at line 58 of file lstm.h.

58  STRING spec() const override {
59  STRING spec;
60  if (type_ == NT_LSTM)
61  spec.add_str_int("Lfx", ns_);
62  else if (type_ == NT_LSTM_SUMMARY)
63  spec.add_str_int("Lfxs", ns_);
64  else if (type_ == NT_LSTM_SOFTMAX)
65  spec.add_str_int("LS", ns_);
66  else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
67  spec.add_str_int("LE", ns_);
68  if (softmax_ != nullptr) spec += softmax_->spec();
69  return spec;
70  }
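
So a plain NT_LSTM with ns_ = 256 reports "Lfx256", an NT_LSTM_SUMMARY "Lfxs256", and the softmax variants "LS256" or "LE256" followed by the spec of the internal softmax layer.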

◆ Update()

void tesseract::LSTM::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 666 of file lstm.cpp.

666 void LSTM::Update(float learning_rate, float momentum, float adam_beta,
667  int num_samples) {
668 #if DEBUG_DETAIL > 3
669  PrintW();
670 #endif
671  for (int w = 0; w < WT_COUNT; ++w) {
672  if (w == GFS && !Is2D()) continue;
673  gate_weights_[w].Update(learning_rate, momentum, adam_beta, num_samples);
674  }
675  if (softmax_ != nullptr) {
676  softmax_->Update(learning_rate, momentum, adam_beta, num_samples);
677  }
678 #if DEBUG_DETAIL > 3
679  PrintDW();
680 #endif
681 }

The documentation for this class was generated from the following files:
lstm.h
lstm.cpp
Definition: network.cpp:76