tesseract  5.0.0-alpha-619-ge9db
tesseract::LSTMTrainer Class Reference

#include <lstmtrainer.h>

Inheritance diagram for tesseract::LSTMTrainer:
tesseract::LSTMRecognizer

Public Member Functions

 LSTMTrainer ()
 
 LSTMTrainer (const char *model_base, const char *checkpoint_name, int debug_interval, int64_t max_memory)
 
virtual ~LSTMTrainer ()
 
bool TryLoadingCheckpoint (const char *filename, const char *old_traineddata)
 
void InitCharSet (const std::string &traineddata_path)
 
void InitCharSet (const TessdataManager &mgr)
 
bool InitNetwork (const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
 
int InitTensorFlowNetwork (const std::string &tf_proto)
 
void InitIterations ()
 
double ActivationError () const
 
double CharError () const
 
const double * error_rates () const
 
double best_error_rate () const
 
int best_iteration () const
 
int learning_iteration () const
 
int32_t improvement_steps () const
 
void set_perfect_delay (int delay)
 
const GenericVector< char > & best_trainer () const
 
double NewSingleError (ErrorTypes type) const
 
double LastSingleError (ErrorTypes type) const
 
const DocumentCache & training_data () const
 
DocumentCache * mutable_training_data ()
 
Trainability GridSearchDictParams (const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
 
void DebugNetwork ()
 
bool LoadAllTrainingData (const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
 
bool MaintainCheckpoints (TestCallback tester, STRING *log_msg)
 
bool MaintainCheckpointsSpecific (int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
 
void PrepareLogMsg (STRING *log_msg) const
 
void LogIterations (const char *intro_str, STRING *log_msg) const
 
bool TransitionTrainingStage (float error_threshold)
 
int CurrentTrainingStage () const
 
bool Serialize (SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
void StartSubtrainer (STRING *log_msg)
 
SubTrainerResult UpdateSubtrainer (STRING *log_msg)
 
void ReduceLearningRates (LSTMTrainer *samples_trainer, STRING *log_msg)
 
int ReduceLayerLearningRates (double factor, int num_samples, LSTMTrainer *samples_trainer)
 
bool EncodeString (const STRING &str, GenericVector< int > *labels) const
 
const ImageData * TrainOnLine (LSTMTrainer *samples_trainer, bool batch)
 
Trainability TrainOnLine (const ImageData *trainingdata, bool batch)
 
Trainability PrepareForBackward (const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
 
bool SaveTrainingDump (SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
 
bool ReadTrainingDump (const GenericVector< char > &data, LSTMTrainer *trainer) const
 
bool ReadSizedTrainingDump (const char *data, int size, LSTMTrainer *trainer) const
 
bool ReadLocalTrainingDump (const TessdataManager *mgr, const char *data, int size)
 
void SetupCheckpointInfo ()
 
bool SaveTraineddata (const STRING &filename)
 
void SaveRecognitionDump (GenericVector< char > *data) const
 
STRING DumpFilename () const
 
void FillErrorBuffer (double new_error, ErrorTypes type)
 
std::vector< int > MapRecoder (const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
 
- Public Member Functions inherited from tesseract::LSTMRecognizer
 LSTMRecognizer ()
 
 LSTMRecognizer (const STRING language_data_path_prefix)
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
double learning_rate () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
bool IsTensorFlow () const
 
GenericVector< STRING > EnumerateLayers () const
 
Network * GetLayer (const STRING &id) const
 
float GetLayerLearningRate (const STRING &id) const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const STRING &id, double factor)
 
void ConvertToInt ()
 
const UNICHARSET & GetUnicharset () const
 
UNICHARSET & GetUnicharset ()
 
const UnicharCompress & GetRecoder () const
 
const Dict * GetDict () const
 
Dict * GetDict ()
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Load (const ParamsVectors *params, const char *lang, TessdataManager *mgr)
 
bool Serialize (const TessdataManager *mgr, TFile *fp) const
 
bool DeSerialize (const TessdataManager *mgr, TFile *fp)
 
bool LoadCharsets (const TessdataManager *mgr)
 
bool LoadRecoder (TFile *fp)
 
bool LoadDictionary (const ParamsVectors *params, const char *lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, bool upside_down, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
STRING DecodeLabels (const GenericVector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
 
void LabelsFromOutputs (const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
 

Static Public Member Functions

static bool EncodeString (const STRING &str, const UNICHARSET &unicharset, const UnicharCompress *recoder, bool simple_text, int null_char, GenericVector< int > *labels)
 

Protected Member Functions

void InitCharSet ()
 
void SetNullChar ()
 
void EmptyConstructor ()
 
bool DebugLSTMTraining (const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
 
void DisplayTargets (const NetworkIO &targets, const char *window_name, ScrollView **window)
 
bool ComputeTextTargets (const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
 
bool ComputeCTCTargets (const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
 
double ComputeErrorRates (const NetworkIO &deltas, double char_error, double word_error)
 
double ComputeRMSError (const NetworkIO &deltas)
 
double ComputeWinnerError (const NetworkIO &deltas)
 
double ComputeCharError (const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
 
double ComputeWordError (STRING *truth_str, STRING *ocr_str)
 
void UpdateErrorBuffer (double new_error, ErrorTypes type)
 
void RollErrorBuffers ()
 
STRING UpdateErrorGraph (int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
 
- Protected Member Functions inherited from tesseract::LSTMRecognizer
void SetRandomSeed ()
 
void DisplayLSTMOutput (const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsViaReEncode (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
const char * DecodeLabel (const GenericVector< int > &labels, int start, int *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

ScrollView * align_win_
 
ScrollView * target_win_
 
ScrollView * ctc_win_
 
ScrollView * recon_win_
 
int debug_interval_
 
int checkpoint_iteration_
 
STRING model_base_
 
STRING checkpoint_name_
 
bool randomly_rotate_
 
DocumentCache training_data_
 
STRING best_model_name_
 
int num_training_stages_
 
double best_error_rate_
 
double best_error_rates_ [ET_COUNT]
 
int best_iteration_
 
double worst_error_rate_
 
double worst_error_rates_ [ET_COUNT]
 
int worst_iteration_
 
int stall_iteration_
 
GenericVector< char > best_model_data_
 
GenericVector< char > worst_model_data_
 
GenericVector< char > best_trainer_
 
LSTMTrainer * sub_trainer_
 
float error_rate_of_last_saved_best_
 
int training_stage_
 
GenericVector< double > best_error_history_
 
GenericVector< int > best_error_iterations_
 
int32_t improvement_steps_
 
int learning_iteration_
 
int prev_sample_iteration_
 
int perfect_delay_
 
int last_perfect_training_iteration_
 
GenericVector< double > error_buffers_ [ET_COUNT]
 
double error_rates_ [ET_COUNT]
 
TessdataManager mgr_
 
- Protected Attributes inherited from tesseract::LSTMRecognizer
Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
STRING network_str_
 
int32_t training_flags_
 
int32_t training_iteration_
 
int32_t sample_iteration_
 
int32_t null_char_
 
float learning_rate_
 
float momentum_
 
float adam_beta_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Static Protected Attributes

static const int kRollingBufferSize_ = 1000
 

Detailed Description

Definition at line 79 of file lstmtrainer.h.

Constructor & Destructor Documentation

◆ LSTMTrainer() [1/2]

tesseract::LSTMTrainer::LSTMTrainer ( )

Definition at line 74 of file lstmtrainer.cpp.

75  : randomly_rotate_(false),
76  training_data_(0),
77  sub_trainer_(nullptr) {
79  debug_interval_ = 0;
80 }

◆ LSTMTrainer() [2/2]

tesseract::LSTMTrainer::LSTMTrainer ( const char *  model_base,
const char *  checkpoint_name,
int  debug_interval,
int64_t  max_memory 
)

Definition at line 82 of file lstmtrainer.cpp.

84  : randomly_rotate_(false),
85  training_data_(max_memory),
86  sub_trainer_(nullptr) {
88  debug_interval_ = debug_interval;
89  model_base_ = model_base;
90  checkpoint_name_ = checkpoint_name;
91 }

◆ ~LSTMTrainer()

tesseract::LSTMTrainer::~LSTMTrainer ( )
virtual

Definition at line 93 of file lstmtrainer.cpp.

93  {
94  delete align_win_;
95  delete target_win_;
96  delete ctc_win_;
97  delete recon_win_;
98  delete sub_trainer_;
99 }

Member Function Documentation

◆ ActivationError()

double tesseract::LSTMTrainer::ActivationError ( ) const
inline

Definition at line 122 of file lstmtrainer.h.

122  {
123  return error_rates_[ET_DELTA];
124  }

◆ best_error_rate()

double tesseract::LSTMTrainer::best_error_rate ( ) const
inline

Definition at line 129 of file lstmtrainer.h.

129  {
130  return best_error_rate_;
131  }

◆ best_iteration()

int tesseract::LSTMTrainer::best_iteration ( ) const
inline

Definition at line 132 of file lstmtrainer.h.

132  {
133  return best_iteration_;
134  }

◆ best_trainer()

const GenericVector<char>& tesseract::LSTMTrainer::best_trainer ( ) const
inline

Definition at line 138 of file lstmtrainer.h.

138 { return best_trainer_; }

◆ CharError()

double tesseract::LSTMTrainer::CharError ( ) const
inline

Definition at line 125 of file lstmtrainer.h.

125 { return error_rates_[ET_CHAR_ERROR]; }

◆ ComputeCharError()

double tesseract::LSTMTrainer::ComputeCharError ( const GenericVector< int > &  truth_str,
const GenericVector< int > &  ocr_str 
)
protected

Definition at line 1167 of file lstmtrainer.cpp.

1168  {
1169  GenericVector<int> label_counts;
1170  label_counts.init_to_size(NumOutputs(), 0);
1171  int truth_size = 0;
1172  for (int i = 0; i < truth_str.size(); ++i) {
1173  if (truth_str[i] != null_char_) {
1174  ++label_counts[truth_str[i]];
1175  ++truth_size;
1176  }
1177  }
1178  for (int i = 0; i < ocr_str.size(); ++i) {
1179  if (ocr_str[i] != null_char_) {
1180  --label_counts[ocr_str[i]];
1181  }
1182  }
1183  int char_errors = 0;
1184  for (int i = 0; i < label_counts.size(); ++i) {
1185  char_errors += abs(label_counts[i]);
1186  }
1187  if (truth_size == 0) {
1188  return (char_errors == 0) ? 0.0 : 1.0;
1189  }
1190  return static_cast<double>(char_errors) / truth_size;
1191 }

◆ ComputeCTCTargets()

bool tesseract::LSTMTrainer::ComputeCTCTargets ( const GenericVector< int > &  truth_labels,
NetworkIO *  outputs,
NetworkIO *  targets 
)
protected

Definition at line 1099 of file lstmtrainer.cpp.

1100  {
1101  // Bottom-clip outputs to a minimum probability.
1102  CTC::NormalizeProbs(outputs);
1103  return CTC::ComputeCTCTargets(truth_labels, null_char_,
1104  outputs->float_array(), targets);
1105 }

◆ ComputeErrorRates()

double tesseract::LSTMTrainer::ComputeErrorRates ( const NetworkIO &  deltas,
double  char_error,
double  word_error 
)
protected

Definition at line 1110 of file lstmtrainer.cpp.

1111  {
1113  // Delta error is the fraction of timesteps with >0.5 error in the top choice
1114  // score. If zero, then the top choice characters are guaranteed correct,
1115  // even when there is residue in the RMS error.
1116  double delta_error = ComputeWinnerError(deltas);
1117  UpdateErrorBuffer(delta_error, ET_DELTA);
1118  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1119  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1120  // Skip ratio measures the difference between sample_iteration_ and
1121  // training_iteration_, which reflects the number of unusable samples,
1122  // usually due to unencodable truth text, or the text not fitting in the
1123  // space for the output.
1124  double skip_count = sample_iteration_ - prev_sample_iteration_;
1125  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1126  return delta_error;
1127 }

◆ ComputeRMSError()

double tesseract::LSTMTrainer::ComputeRMSError ( const NetworkIO &  deltas)
protected

Definition at line 1130 of file lstmtrainer.cpp.

1130  {
1131  double total_error = 0.0;
1132  int width = deltas.Width();
1133  int num_classes = deltas.NumFeatures();
1134  for (int t = 0; t < width; ++t) {
1135  const float* class_errs = deltas.f(t);
1136  for (int c = 0; c < num_classes; ++c) {
1137  double error = class_errs[c];
1138  total_error += error * error;
1139  }
1140  }
1141  return sqrt(total_error / (width * num_classes));
1142 }

◆ ComputeTextTargets()

bool tesseract::LSTMTrainer::ComputeTextTargets ( const NetworkIO &  outputs,
const GenericVector< int > &  truth_labels,
NetworkIO *  targets 
)
protected

Definition at line 1079 of file lstmtrainer.cpp.

1081  {
1082  if (truth_labels.size() > targets->Width()) {
1083  tprintf("Error: transcription %s too long to fit into target of width %d\n",
1084  DecodeLabels(truth_labels).c_str(), targets->Width());
1085  return false;
1086  }
1087  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
1088  targets->SetActivations(i, truth_labels[i], 1.0);
1089  }
1090  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
1091  targets->SetActivations(i, null_char_, 1.0);
1092  }
1093  return true;
1094 }

◆ ComputeWinnerError()

double tesseract::LSTMTrainer::ComputeWinnerError ( const NetworkIO &  deltas)
protected

Definition at line 1149 of file lstmtrainer.cpp.

1149  {
1150  int num_errors = 0;
1151  int width = deltas.Width();
1152  int num_classes = deltas.NumFeatures();
1153  for (int t = 0; t < width; ++t) {
1154  const float* class_errs = deltas.f(t);
1155  for (int c = 0; c < num_classes; ++c) {
1156  float abs_delta = fabs(class_errs[c]);
1157  // TODO(rays) Filtering cases where the delta is very large to cut out
1158  // GT errors doesn't work. Find a better way or get better truth.
1159  if (0.5 <= abs_delta)
1160  ++num_errors;
1161  }
1162  }
1163  return static_cast<double>(num_errors) / width;
1164 }

◆ ComputeWordError()

double tesseract::LSTMTrainer::ComputeWordError ( STRING *  truth_str,
STRING *  ocr_str 
)
protected

Definition at line 1195 of file lstmtrainer.cpp.

1195  {
1196  using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1197  GenericVector<STRING> truth_words, ocr_words;
1198  truth_str->split(' ', &truth_words);
1199  if (truth_words.empty()) return 0.0;
1200  ocr_str->split(' ', &ocr_words);
1201  StrMap word_counts;
1202  for (int i = 0; i < truth_words.size(); ++i) {
1203  std::string truth_word(truth_words[i].c_str());
1204  auto it = word_counts.find(truth_word);
1205  if (it == word_counts.end())
1206  word_counts.insert(std::make_pair(truth_word, 1));
1207  else
1208  ++it->second;
1209  }
1210  for (int i = 0; i < ocr_words.size(); ++i) {
1211  std::string ocr_word(ocr_words[i].c_str());
1212  auto it = word_counts.find(ocr_word);
1213  if (it == word_counts.end())
1214  word_counts.insert(std::make_pair(ocr_word, -1));
1215  else
1216  --it->second;
1217  }
1218  int word_recall_errs = 0;
1219  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1220  ++it) {
1221  if (it->second > 0) word_recall_errs += it->second;
1222  }
1223  return static_cast<double>(word_recall_errs) / truth_words.size();
1224 }

◆ CurrentTrainingStage()

int tesseract::LSTMTrainer::CurrentTrainingStage ( ) const
inline

Definition at line 197 of file lstmtrainer.h.

197 { return training_stage_; }

◆ DebugLSTMTraining()

bool tesseract::LSTMTrainer::DebugLSTMTraining ( const NetworkIO &  inputs,
const ImageData &  trainingdata,
const NetworkIO &  fwd_outputs,
const GenericVector< int > &  truth_labels,
const NetworkIO &  outputs 
)
protected

Definition at line 1005 of file lstmtrainer.cpp.

1009  {
1010  const STRING& truth_text = DecodeLabels(truth_labels);
1011  if (truth_text.c_str() == nullptr || truth_text.length() <= 0) {
1012  tprintf("Empty truth string at decode time!\n");
1013  return false;
1014  }
1015  if (debug_interval_ != 0) {
1016  // Get class labels, xcoords and string.
1017  GenericVector<int> labels;
1018  GenericVector<int> xcoords;
1019  LabelsFromOutputs(outputs, &labels, &xcoords);
1020  STRING text = DecodeLabels(labels);
1021  tprintf("Iteration %d: GROUND TRUTH : %s\n",
1022  training_iteration(), truth_text.c_str());
1023  if (truth_text != text) {
1024  tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
1025  training_iteration(), text.c_str());
1026  }
1027  if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1028  tprintf("TRAINING activation path for truth string %s\n",
1029  truth_text.c_str());
1030  DebugActivationPath(outputs, labels, xcoords);
1031  DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1032  if (OutputLossType() == LT_CTC) {
1033  DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1034  DisplayTargets(outputs, "CTC Targets", &target_win_);
1035  }
1036  }
1037  }
1038  return true;
1039 }

◆ DebugNetwork()

void tesseract::LSTMTrainer::DebugNetwork ( )

Definition at line 265 of file lstmtrainer.cpp.

265  {
267 }

◆ DeSerialize()

bool tesseract::LSTMTrainer::DeSerialize ( const TessdataManager *  mgr,
TFile *  fp 
)

Definition at line 440 of file lstmtrainer.cpp.

440  {
441  if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
442  if (!fp->DeSerialize(&learning_iteration_)) {
443  // Special case. If we successfully decoded the recognizer, but fail here
444  // then it means we were just given a recognizer, so issue a warning and
445  // allow it.
446  tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
449  return true;
450  }
451  if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
452  if (!fp->DeSerialize(&perfect_delay_)) return false;
453  if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
454  for (auto & error_buffer : error_buffers_) {
455  if (!error_buffer.DeSerialize(fp)) return false;
456  }
457  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
458  if (!fp->DeSerialize(&training_stage_)) return false;
459  uint8_t amount;
460  if (!fp->DeSerialize(&amount)) return false;
461  if (amount == LIGHT) return true; // Don't read the rest.
462  if (!fp->DeSerialize(&best_error_rate_)) return false;
463  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
464  if (!fp->DeSerialize(&best_iteration_)) return false;
465  if (!fp->DeSerialize(&worst_error_rate_)) return false;
466  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
467  if (!fp->DeSerialize(&worst_iteration_)) return false;
468  if (!fp->DeSerialize(&stall_iteration_)) return false;
469  if (!best_model_data_.DeSerialize(fp)) return false;
470  if (!worst_model_data_.DeSerialize(fp)) return false;
471  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
472  GenericVector<char> sub_data;
473  if (!sub_data.DeSerialize(fp)) return false;
474  delete sub_trainer_;
475  if (sub_data.empty()) {
476  sub_trainer_ = nullptr;
477  } else {
478  sub_trainer_ = new LSTMTrainer();
479  if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
480  }
481  if (!best_error_history_.DeSerialize(fp)) return false;
482  if (!best_error_iterations_.DeSerialize(fp)) return false;
483  return fp->DeSerialize(&improvement_steps_);
484 }

◆ DisplayTargets()

void tesseract::LSTMTrainer::DisplayTargets ( const NetworkIO &  targets,
const char *  window_name,
ScrollView **  window 
)
protected

Definition at line 1042 of file lstmtrainer.cpp.

1043  {
1044 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1045  int width = targets.Width();
1046  int num_features = targets.NumFeatures();
1047  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1048  window);
1049  for (int c = 0; c < num_features; ++c) {
1050  int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1051  (*window)->Pen(static_cast<ScrollView::Color>(color));
1052  int start_t = -1;
1053  for (int t = 0; t < width; ++t) {
1054  double target = targets.f(t)[c];
1055  target *= kTargetYScale;
1056  if (target >= 1) {
1057  if (start_t < 0) {
1058  (*window)->SetCursor(t - 1, 0);
1059  start_t = t;
1060  }
1061  (*window)->DrawTo(t, target);
1062  } else if (start_t >= 0) {
1063  (*window)->DrawTo(t, 0);
1064  (*window)->DrawTo(start_t - 1, 0);
1065  start_t = -1;
1066  }
1067  }
1068  if (start_t >= 0) {
1069  (*window)->DrawTo(width, 0);
1070  (*window)->DrawTo(start_t - 1, 0);
1071  }
1072  }
1073  (*window)->Update();
1074 #endif // GRAPHICS_DISABLED
1075 }

◆ DumpFilename()

STRING tesseract::LSTMTrainer::DumpFilename ( ) const

Definition at line 914 of file lstmtrainer.cpp.

914  {
915  STRING filename;
916  filename += model_base_.c_str();
917  filename.add_str_double("_", best_error_rate_);
918  filename.add_str_int("_", best_iteration_);
919  filename.add_str_int("_", training_iteration_);
920  filename += ".checkpoint";
921  return filename;
922 }

◆ EmptyConstructor()

void tesseract::LSTMTrainer::EmptyConstructor ( )
protected

Definition at line 990 of file lstmtrainer.cpp.

990  {
991  align_win_ = nullptr;
992  target_win_ = nullptr;
993  ctc_win_ = nullptr;
994  recon_win_ = nullptr;
996  training_stage_ = 0;
998  InitIterations();
999 }

◆ EncodeString() [1/2]

bool tesseract::LSTMTrainer::EncodeString ( const STRING &  str,
const UNICHARSET &  unicharset,
const UnicharCompress *  recoder,
bool  simple_text,
int  null_char,
GenericVector< int > *  labels 
)
static

Definition at line 690 of file lstmtrainer.cpp.

692  {
693  if (str.c_str() == nullptr || str.length() <= 0) {
694  tprintf("Empty truth string!\n");
695  return false;
696  }
697  int err_index;
698  GenericVector<int> internal_labels;
699  labels->truncate(0);
700  if (!simple_text) labels->push_back(null_char);
701  std::string cleaned = unicharset.CleanupString(str.c_str());
702  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
703  &err_index)) {
704  bool success = true;
705  for (int i = 0; i < internal_labels.size(); ++i) {
706  if (recoder != nullptr) {
707  // Re-encode labels via recoder.
708  RecodedCharID code;
709  int len = recoder->EncodeUnichar(internal_labels[i], &code);
710  if (len > 0) {
711  for (int j = 0; j < len; ++j) {
712  labels->push_back(code(j));
713  if (!simple_text) labels->push_back(null_char);
714  }
715  } else {
716  success = false;
717  err_index = 0;
718  break;
719  }
720  } else {
721  labels->push_back(internal_labels[i]);
722  if (!simple_text) labels->push_back(null_char);
723  }
724  }
725  if (success) return true;
726  }
727  tprintf("Encoding of string failed! Failure bytes:");
728  while (err_index < cleaned.size()) {
729  tprintf(" %x", cleaned[err_index++] & 0xff);
730  }
731  tprintf("\n");
732  return false;
733 }

◆ EncodeString() [2/2]

bool tesseract::LSTMTrainer::EncodeString ( const STRING &  str,
GenericVector< int > *  labels 
) const
inline

Definition at line 232 of file lstmtrainer.h.

232  {
233  return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : nullptr,
234  SimpleTextOutput(), null_char_, labels);
235  }

◆ error_rates()

const double* tesseract::LSTMTrainer::error_rates ( ) const
inline

Definition at line 126 of file lstmtrainer.h.

126  {
127  return error_rates_;
128  }

◆ FillErrorBuffer()

void tesseract::LSTMTrainer::FillErrorBuffer ( double  new_error,
ErrorTypes  type 
)

Definition at line 925 of file lstmtrainer.cpp.

925  {
926  for (int i = 0; i < kRollingBufferSize_; ++i)
927  error_buffers_[type][i] = new_error;
928  error_rates_[type] = 100.0 * new_error;
929 }

◆ GridSearchDictParams()

Trainability tesseract::LSTMTrainer::GridSearchDictParams ( const ImageData *  trainingdata,
int  iteration,
double  min_dict_ratio,
double  dict_ratio_step,
double  max_dict_ratio,
double  min_cert_offset,
double  cert_offset_step,
double  max_cert_offset,
STRING *  results 
)

Definition at line 215 of file lstmtrainer.cpp.

218  {
219  sample_iteration_ = iteration;
220  NetworkIO fwd_outputs, targets;
221  Trainability result =
222  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
223  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr)
224  return result;
225 
226  // Encode/decode the truth to get the normalization.
227  GenericVector<int> truth_labels, ocr_labels, xcoords;
228  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
229  // NO-dict error.
230  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), nullptr);
231  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
232  nullptr);
233  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
234  STRING truth_text = DecodeLabels(truth_labels);
235  STRING ocr_text = DecodeLabels(ocr_labels);
236  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
237  results->add_str_double("0,0=", baseline_error);
238 
239  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
240  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
241  for (double c = min_cert_offset; c < max_cert_offset;
242  c += cert_offset_step) {
243  search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, nullptr);
244  search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
245  truth_text = DecodeLabels(truth_labels);
246  ocr_text = DecodeLabels(ocr_labels);
247  // This is destructive on both strings.
248  double word_error = ComputeWordError(&truth_text, &ocr_text);
249  if ((r == min_dict_ratio && c == min_cert_offset) ||
250  !std::isfinite(word_error)) {
251  STRING t = DecodeLabels(truth_labels);
252  STRING o = DecodeLabels(ocr_labels);
253  tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
254  t.c_str(), o.c_str(), word_error, truth_labels[0]);
255  }
256  results->add_str_double(" ", r);
257  results->add_str_double(",", c);
258  results->add_str_double("=", word_error);
259  }
260  }
261  return result;
262 }

◆ improvement_steps()

int32_t tesseract::LSTMTrainer::improvement_steps ( ) const
inline

Definition at line 136 of file lstmtrainer.h.

136 { return improvement_steps_; }

◆ InitCharSet() [1/3]

void tesseract::LSTMTrainer::InitCharSet ( )
protected

Definition at line 968 of file lstmtrainer.cpp.

968  {
971  // Initialize the unicharset and recoder.
972  if (!LoadCharsets(&mgr_)) {
973  ASSERT_HOST(
974  "Must provide a traineddata containing lstm_unicharset and"
975  " lstm_recoder!\n" != nullptr);
976  }
977  SetNullChar();
978 }

◆ InitCharSet() [2/3]

void tesseract::LSTMTrainer::InitCharSet ( const std::string &  traineddata_path)
inline

Definition at line 95 of file lstmtrainer.h.

95  {
96  ASSERT_HOST(mgr_.Init(traineddata_path.c_str()));
97  InitCharSet();
98  }

◆ InitCharSet() [3/3]

void tesseract::LSTMTrainer::InitCharSet ( const TessdataManager &  mgr)
inline

Definition at line 99 of file lstmtrainer.h.

99  {
100  mgr_ = mgr;
101  InitCharSet();
102  }

◆ InitIterations()

void tesseract::LSTMTrainer::InitIterations ( )

Definition at line 190 of file lstmtrainer.cpp.

190  {
191  sample_iteration_ = 0;
195  best_error_rate_ = 100.0;
196  best_iteration_ = 0;
197  worst_error_rate_ = 0.0;
198  worst_iteration_ = 0;
201  perfect_delay_ = 0;
203  for (int i = 0; i < ET_COUNT; ++i) {
204  best_error_rates_[i] = 100.0;
205  worst_error_rates_[i] = 0.0;
207  error_rates_[i] = 100.0;
208  }
210 }

◆ InitNetwork()

bool tesseract::LSTMTrainer::InitNetwork ( const STRING &  network_spec,
int  append_index,
int  net_flags,
float  weight_range,
float  learning_rate,
float  momentum,
float  adam_beta 
)

Definition at line 146 of file lstmtrainer.cpp.

149  {
150  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.c_str());
151  adam_beta_ = adam_beta;
153  momentum_ = momentum;
154  SetNullChar();
155  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
156  append_index, net_flags, weight_range,
157  &randomizer_, &network_)) {
158  return false;
159  }
160  network_str_ += network_spec;
161  tprintf("Built network:%s from request %s\n",
162  network_->spec().c_str(), network_spec.c_str());
163  tprintf(
164  "Training parameters:\n Debug interval = %d,"
165  " weights = %g, learning rate = %g, momentum=%g\n",
166  debug_interval_, weight_range, learning_rate_, momentum_);
167  tprintf("null char=%d\n", null_char_);
168  return true;
169 }

◆ InitTensorFlowNetwork()

int tesseract::LSTMTrainer::InitTensorFlowNetwork ( const std::string &  tf_proto)

◆ LastSingleError()

double tesseract::LSTMTrainer::LastSingleError ( ErrorTypes  type) const
inline

Definition at line 146 of file lstmtrainer.h.

146  {
147  return error_buffers_[type]
150  }

◆ learning_iteration()

int tesseract::LSTMTrainer::learning_iteration ( ) const
inline

Definition at line 135 of file lstmtrainer.h.

135 { return learning_iteration_; }

◆ LoadAllTrainingData()

bool tesseract::LSTMTrainer::LoadAllTrainingData ( const GenericVector< STRING > &  filenames,
CachingStrategy  cache_strategy,
bool  randomly_rotate 
)

Definition at line 272 of file lstmtrainer.cpp.

274  {
275  randomly_rotate_ = randomly_rotate;
277  return training_data_.LoadDocuments(filenames, cache_strategy,
279 }

◆ LogIterations()

void tesseract::LSTMTrainer::LogIterations ( const char *  intro_str,
STRING *  log_msg 
) const

Definition at line 384 of file lstmtrainer.cpp.

384  {
385  *log_msg += intro_str;
386  log_msg->add_str_int(" iteration ", learning_iteration());
387  log_msg->add_str_int("/", training_iteration());
388  log_msg->add_str_int("/", sample_iteration());
389 }

◆ MaintainCheckpoints()

bool tesseract::LSTMTrainer::MaintainCheckpoints ( TestCallback  tester,
STRING *  log_msg 
)

Definition at line 285 of file lstmtrainer.cpp.

285  {
286  PrepareLogMsg(log_msg);
287  double error_rate = CharError();
288  int iteration = learning_iteration();
289  if (iteration >= stall_iteration_ &&
290  error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
292  // It hasn't got any better in a long while, and is a margin worse than the
293  // best, so go back to the best model and try a different learning rate.
294  StartSubtrainer(log_msg);
295  }
296  SubTrainerResult sub_trainer_result = STR_NONE;
297  if (sub_trainer_ != nullptr) {
298  sub_trainer_result = UpdateSubtrainer(log_msg);
299  if (sub_trainer_result == STR_REPLACED) {
300  // Reset the inputs, as we have overwritten *this.
301  error_rate = CharError();
302  iteration = learning_iteration();
303  PrepareLogMsg(log_msg);
304  }
305  }
306  bool result = true; // Something interesting happened.
307  GenericVector<char> rec_model_data;
308  if (error_rate < best_error_rate_) {
309  SaveRecognitionDump(&rec_model_data);
310  log_msg->add_str_double(" New best char error = ", error_rate);
311  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
312  // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
313  // just overwrote *this. In either case, we have finished with it.
314  delete sub_trainer_;
315  sub_trainer_ = nullptr;
318  log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
319  }
322  STRING best_model_name = DumpFilename();
323  if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
324  *log_msg += " failed to write best model:";
325  } else {
326  *log_msg += " wrote best model:";
328  }
329  *log_msg += best_model_name;
330  }
331  } else if (error_rate > worst_error_rate_) {
332  SaveRecognitionDump(&rec_model_data);
333  log_msg->add_str_double(" New worst char error = ", error_rate);
334  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
337  // Error rate has ballooned. Go back to the best model.
338  *log_msg += "\nDivergence! ";
339  // Copy best_trainer_ before reading it, as it will get overwritten.
340  GenericVector<char> revert_data(best_trainer_);
341  if (ReadTrainingDump(revert_data, this)) {
342  LogIterations("Reverted to", log_msg);
343  ReduceLearningRates(this, log_msg);
344  } else {
345  LogIterations("Failed to Revert at", log_msg);
346  }
347  // If it fails again, we will wait twice as long before reverting again.
348  stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
349  // Re-save the best trainer with the new learning rates and stall
350  // iteration.
352  }
353  } else {
354  // Something interesting happened only if the sub_trainer_ was trained.
355  result = sub_trainer_result != STR_NONE;
356  }
357  if (checkpoint_name_.length() > 0) {
358  // Write a current checkpoint.
359  GenericVector<char> checkpoint;
360  if (!SaveTrainingDump(FULL, this, &checkpoint) ||
361  !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
362  *log_msg += " failed to write checkpoint.";
363  } else {
364  *log_msg += " wrote checkpoint.";
365  }
366  }
367  *log_msg += "\n";
368  return result;
369 }

◆ MaintainCheckpointsSpecific()

bool tesseract::LSTMTrainer::MaintainCheckpointsSpecific ( int  iteration,
const GenericVector< char > *  train_model,
const GenericVector< char > *  rec_model,
TestCallback  tester,
STRING *  log_msg 
)

◆ MapRecoder()

std::vector< int > tesseract::LSTMTrainer::MapRecoder ( const UNICHARSET &  old_chset,
const UnicharCompress &  old_recoder 
) const

Definition at line 933 of file lstmtrainer.cpp.

934  {
935  int num_new_codes = recoder_.code_range();
936  int num_new_unichars = GetUnicharset().size();
937  std::vector<int> code_map(num_new_codes, -1);
938  for (int c = 0; c < num_new_codes; ++c) {
939  int old_code = -1;
940  // Find all new unichar_ids that recode to something that includes c.
941  // The <= is to include the null char, which may be beyond the unicharset.
942  for (int uid = 0; uid <= num_new_unichars; ++uid) {
943  RecodedCharID codes;
944  int length = recoder_.EncodeUnichar(uid, &codes);
945  int code_index = 0;
946  while (code_index < length && codes(code_index) != c) ++code_index;
947  if (code_index == length) continue;
948  // The old unicharset must have the same unichar.
949  int old_uid =
950  uid < num_new_unichars
951  ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
952  : old_chset.size() - 1;
953  if (old_uid == INVALID_UNICHAR_ID) continue;
954  // The encoding of old_uid at the same code_index is the old code.
955  RecodedCharID old_codes;
956  if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
957  old_code = old_codes(code_index);
958  break;
959  }
960  }
961  code_map[c] = old_code;
962  }
963  return code_map;
964 }

◆ mutable_training_data()

DocumentCache* tesseract::LSTMTrainer::mutable_training_data ( )
inline

Definition at line 154 of file lstmtrainer.h.

154 { return &training_data_; }

◆ NewSingleError()

double tesseract::LSTMTrainer::NewSingleError ( ErrorTypes  type) const
inline

Definition at line 140 of file lstmtrainer.h.

140  {
142  }

◆ PrepareForBackward()

Trainability tesseract::LSTMTrainer::PrepareForBackward ( const ImageData *  trainingdata,
NetworkIO *  fwd_outputs,
NetworkIO *  targets 
)

Definition at line 770 of file lstmtrainer.cpp.

772  {
773  if (trainingdata == nullptr) {
774  tprintf("Null trainingdata.\n");
775  return UNENCODABLE;
776  }
777  // Ensure repeatability of random elements even across checkpoints.
778  bool debug = debug_interval_ > 0 &&
780  GenericVector<int> truth_labels;
781  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
782  tprintf("Can't encode transcription: '%s' in language '%s'\n",
783  trainingdata->transcription().c_str(),
784  trainingdata->language().c_str());
785  return UNENCODABLE;
786  }
787  bool upside_down = false;
788  if (randomly_rotate_) {
789  // This ensures consistent training results.
790  SetRandomSeed();
791  upside_down = randomizer_.SignedRand(1.0) > 0.0;
792  if (upside_down) {
793  // Modify the truth labels to match the rotation:
794  // Apart from space and null, increment the label. This changes the
795  // script-id to the same script-id but upside-down.
796  // The labels need to be reversed in order, as the first is now the last.
797  for (int c = 0; c < truth_labels.size(); ++c) {
798  if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
799  ++truth_labels[c];
800  }
801  truth_labels.reverse();
802  }
803  }
804  int w = 0;
805  while (w < truth_labels.size() &&
806  (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
807  ++w;
808  if (w == truth_labels.size()) {
809  tprintf("Blank transcription: %s\n",
810  trainingdata->transcription().c_str());
811  return UNENCODABLE;
812  }
813  float image_scale;
814  NetworkIO inputs;
815  bool invert = trainingdata->boxes().empty();
816  if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
817  &image_scale, &inputs, fwd_outputs)) {
818  tprintf("Image not trainable\n");
819  return UNENCODABLE;
820  }
821  targets->Resize(*fwd_outputs, network_->NumOutputs());
822  LossType loss_type = OutputLossType();
823  if (loss_type == LT_SOFTMAX) {
824  if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
825  tprintf("Compute simple targets failed!\n");
826  return UNENCODABLE;
827  }
828  } else if (loss_type == LT_CTC) {
829  if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
830  tprintf("Compute CTC targets failed!\n");
831  return UNENCODABLE;
832  }
833  } else {
834  tprintf("Logistic outputs not implemented yet!\n");
835  return UNENCODABLE;
836  }
837  GenericVector<int> ocr_labels;
838  GenericVector<int> xcoords;
839  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
840  // CTC does not produce correct target labels to begin with.
841  if (loss_type != LT_CTC) {
842  LabelsFromOutputs(*targets, &truth_labels, &xcoords);
843  }
844  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
845  *targets)) {
846  tprintf("Input width was %d\n", inputs.Width());
847  return UNENCODABLE;
848  }
849  STRING ocr_text = DecodeLabels(ocr_labels);
850  STRING truth_text = DecodeLabels(truth_labels);
851  targets->SubtractAllFromFloat(*fwd_outputs);
852  if (debug_interval_ != 0) {
853  if (truth_text != ocr_text) {
854  tprintf("Iteration %d: BEST OCR TEXT : %s\n",
855  training_iteration(), ocr_text.c_str());
856  }
857  }
858  double char_error = ComputeCharError(truth_labels, ocr_labels);
859  double word_error = ComputeWordError(&truth_text, &ocr_text);
860  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
861  if (debug_interval_ != 0) {
862  tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(),
863  trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
864  }
865  if (delta_error == 0.0) return PERFECT;
866  if (targets->AnySuspiciousTruth(kHighConfidence)) return HI_PRECISION_ERR;
867  return TRAINABLE;
868 }

◆ PrepareLogMsg()

void tesseract::LSTMTrainer::PrepareLogMsg ( STRING *  log_msg) const

Definition at line 372 of file lstmtrainer.cpp.

372  {
373  LogIterations("At", log_msg);
374  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
375  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
376  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
377  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
378  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
379  *log_msg += "%, ";
380 }

◆ ReadLocalTrainingDump()

bool tesseract::LSTMTrainer::ReadLocalTrainingDump ( const TessdataManager *  mgr,
const char *  data,
int  size 
)

Definition at line 883 of file lstmtrainer.cpp.

884  {
885  if (size == 0) {
886  tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
887  return false;
888  }
889  TFile fp;
890  fp.Open(data, size);
891  return DeSerialize(mgr, &fp);
892 }

◆ ReadSizedTrainingDump()

bool tesseract::LSTMTrainer::ReadSizedTrainingDump ( const char *  data,
int  size,
LSTMTrainer *  trainer 
) const
inline

Definition at line 282 of file lstmtrainer.h.

283  {
284  return trainer->ReadLocalTrainingDump(&mgr_, data, size);
285  }

◆ ReadTrainingDump()

bool tesseract::LSTMTrainer::ReadTrainingDump ( const GenericVector< char > &  data,
LSTMTrainer *  trainer 
) const
inline

Definition at line 277 of file lstmtrainer.h.

278  {
279  if (data.empty()) return false;
280  return ReadSizedTrainingDump(&data[0], data.size(), trainer);
281  }

◆ ReduceLayerLearningRates()

int tesseract::LSTMTrainer::ReduceLayerLearningRates ( double  factor,
int  num_samples,
LSTMTrainer *  samples_trainer 
)

Definition at line 581 of file lstmtrainer.cpp.

582  {
583  enum WhichWay {
584  LR_DOWN, // Learning rate will go down by factor.
585  LR_SAME, // Learning rate will stay the same.
586  LR_COUNT // Size of arrays.
587  };
589  int num_layers = layers.size();
590  GenericVector<int> num_weights;
591  num_weights.init_to_size(num_layers, 0);
592  GenericVector<double> bad_sums[LR_COUNT];
593  GenericVector<double> ok_sums[LR_COUNT];
594  for (int i = 0; i < LR_COUNT; ++i) {
595  bad_sums[i].init_to_size(num_layers, 0.0);
596  ok_sums[i].init_to_size(num_layers, 0.0);
597  }
598  double momentum_factor = 1.0 / (1.0 - momentum_);
599  GenericVector<char> orig_trainer;
600  samples_trainer->SaveTrainingDump(LIGHT, this, &orig_trainer);
601  for (int i = 0; i < num_layers; ++i) {
602  Network* layer = GetLayer(layers[i]);
603  num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
604  }
605  int iteration = sample_iteration();
606  for (int s = 0; s < num_samples; ++s) {
607  // Which way will we modify the learning rate?
608  for (int ww = 0; ww < LR_COUNT; ++ww) {
609  // Transfer momentum to learning rate and adjust by the ww factor.
610  float ww_factor = momentum_factor;
611  if (ww == LR_DOWN) ww_factor *= factor;
612  // Make a copy of *this, so we can mess about without damaging anything.
613  LSTMTrainer copy_trainer;
614  samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
615  // Clear the updates, doing nothing else.
616  copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
617  // Adjust the learning rate in each layer.
618  for (int i = 0; i < num_layers; ++i) {
619  if (num_weights[i] == 0) continue;
620  copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
621  }
622  copy_trainer.SetIteration(iteration);
623  // Train on the sample, but keep the update in updates_ instead of
624  // applying to the weights.
625  const ImageData* trainingdata =
626  copy_trainer.TrainOnLine(samples_trainer, true);
627  if (trainingdata == nullptr) continue;
628  // We'll now use this trainer again for each layer.
629  GenericVector<char> updated_trainer;
630  samples_trainer->SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
631  for (int i = 0; i < num_layers; ++i) {
632  if (num_weights[i] == 0) continue;
633  LSTMTrainer layer_trainer;
634  samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
635  Network* layer = layer_trainer.GetLayer(layers[i]);
636  // Update the weights in just the layer, using Adam if enabled.
637  layer->Update(0.0, momentum_, adam_beta_,
638  layer_trainer.training_iteration_ + 1);
639  // Zero the updates matrix again.
640  layer->Update(0.0, 0.0, 0.0, 0);
641  // Train again on the same sample, again holding back the updates.
642  layer_trainer.TrainOnLine(trainingdata, true);
643  // Count the sign changes in the updates in layer vs in copy_trainer.
644  float before_bad = bad_sums[ww][i];
645  float before_ok = ok_sums[ww][i];
646  layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
647  &ok_sums[ww][i], &bad_sums[ww][i]);
648  float bad_frac =
649  bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
650  if (bad_frac > 0.0f)
651  bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
652  }
653  }
654  ++iteration;
655  }
656  int num_lowered = 0;
657  for (int i = 0; i < num_layers; ++i) {
658  if (num_weights[i] == 0) continue;
659  Network* layer = GetLayer(layers[i]);
660  float lr = GetLayerLearningRate(layers[i]);
661  double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
662  double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
663  double frac_down = bad_sums[LR_DOWN][i] / total_down;
664  double frac_same = bad_sums[LR_SAME][i] / total_same;
665  tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
666  lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
667  if (frac_down < frac_same * kImprovementFraction) {
668  tprintf(" REDUCED\n");
669  ScaleLayerLearningRate(layers[i], factor);
670  ++num_lowered;
671  } else {
672  tprintf(" SAME\n");
673  }
674  }
675  if (num_lowered == 0) {
676  // Just lower everything to make sure.
677  for (int i = 0; i < num_layers; ++i) {
678  if (num_weights[i] > 0) {
679  ScaleLayerLearningRate(layers[i], factor);
680  ++num_lowered;
681  }
682  }
683  }
684  return num_lowered;
685 }

◆ ReduceLearningRates()

void tesseract::LSTMTrainer::ReduceLearningRates ( LSTMTrainer *  samples_trainer,
STRING *  log_msg 
)

Definition at line 562 of file lstmtrainer.cpp.

563  {
565  int num_reduced = ReduceLayerLearningRates(
566  kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
567  log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
568  } else {
570  log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
571  }
572  *log_msg += "\n";
573 }

◆ RollErrorBuffers()

void tesseract::LSTMTrainer::RollErrorBuffers ( )
protected

Definition at line 1241 of file lstmtrainer.cpp.

1241  {
1243  if (NewSingleError(ET_DELTA) > 0.0)
1245  else
1248  if (debug_interval_ != 0) {
1249  tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1253  }
1254 }

◆ SaveRecognitionDump()

void tesseract::LSTMTrainer::SaveRecognitionDump ( GenericVector< char > *  data) const

Definition at line 904 of file lstmtrainer.cpp.

904  {
905  TFile fp;
906  fp.OpenWrite(data);
910 }

◆ SaveTraineddata()

bool tesseract::LSTMTrainer::SaveTraineddata ( const STRING filename)

Definition at line 895 of file lstmtrainer.cpp.

895  {
896  GenericVector<char> recognizer_data;
897  SaveRecognitionDump(&recognizer_data);
898  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
899  recognizer_data.size());
900  return mgr_.SaveFile(filename, SaveDataToFile);
901 }

◆ SaveTrainingDump()

bool tesseract::LSTMTrainer::SaveTrainingDump ( SerializeAmount  serialize_amount,
const LSTMTrainer *  trainer,
GenericVector< char > *  data 
) const

Definition at line 874 of file lstmtrainer.cpp.

876  {
877  TFile fp;
878  fp.OpenWrite(data);
879  return trainer->Serialize(serialize_amount, &mgr_, &fp);
880 }

◆ Serialize()

bool tesseract::LSTMTrainer::Serialize ( SerializeAmount  serialize_amount,
const TessdataManager *  mgr,
TFile *  fp 
) const

Definition at line 403 of file lstmtrainer.cpp.

404  {
405  if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
406  if (!fp->Serialize(&learning_iteration_)) return false;
407  if (!fp->Serialize(&prev_sample_iteration_)) return false;
408  if (!fp->Serialize(&perfect_delay_)) return false;
409  if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
410  for (const auto & error_buffer : error_buffers_) {
411  if (!error_buffer.Serialize(fp)) return false;
412  }
413  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
414  if (!fp->Serialize(&training_stage_)) return false;
415  uint8_t amount = serialize_amount;
416  if (!fp->Serialize(&amount)) return false;
417  if (serialize_amount == LIGHT) return true; // We are done.
418  if (!fp->Serialize(&best_error_rate_)) return false;
419  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
420  if (!fp->Serialize(&best_iteration_)) return false;
421  if (!fp->Serialize(&worst_error_rate_)) return false;
422  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
423  if (!fp->Serialize(&worst_iteration_)) return false;
424  if (!fp->Serialize(&stall_iteration_)) return false;
425  if (!best_model_data_.Serialize(fp)) return false;
426  if (!worst_model_data_.Serialize(fp)) return false;
427  if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
428  return false;
429  GenericVector<char> sub_data;
430  if (sub_trainer_ != nullptr && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
431  return false;
432  if (!sub_data.Serialize(fp)) return false;
433  if (!best_error_history_.Serialize(fp)) return false;
434  if (!best_error_iterations_.Serialize(fp)) return false;
435  return fp->Serialize(&improvement_steps_);
436 }

◆ set_perfect_delay()

void tesseract::LSTMTrainer::set_perfect_delay ( int  delay)
inline

Definition at line 137 of file lstmtrainer.h.

137 { perfect_delay_ = delay; }

◆ SetNullChar()

void tesseract::LSTMTrainer::SetNullChar ( )
protected

Definition at line 981 of file lstmtrainer.cpp.

981  {
983  : GetUnicharset().size();
984  RecodedCharID code;
986  null_char_ = code(0);
987 }

◆ SetupCheckpointInfo()

void tesseract::LSTMTrainer::SetupCheckpointInfo ( )

◆ StartSubtrainer()

void tesseract::LSTMTrainer::StartSubtrainer ( STRING *  log_msg)

Definition at line 489 of file lstmtrainer.cpp.

489  {
490  delete sub_trainer_;
491  sub_trainer_ = new LSTMTrainer();
493  *log_msg += " Failed to revert to previous best for trial!";
494  delete sub_trainer_;
495  sub_trainer_ = nullptr;
496  } else {
497  log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
499  // Reduce learning rate so it doesn't diverge this time.
500  sub_trainer_->ReduceLearningRates(this, log_msg);
501  // If it fails again, we will wait twice as long before reverting again.
502  int stall_offset =
504  stall_iteration_ = learning_iteration() + 2 * stall_offset;
506  // Re-save the best trainer with the new learning rates and stall iteration.
508  }
509 }

◆ training_data()

const DocumentCache& tesseract::LSTMTrainer::training_data ( ) const
inline

Definition at line 151 of file lstmtrainer.h.

151  {
152  return training_data_;
153  }

◆ TrainOnLine() [1/2]

Trainability tesseract::LSTMTrainer::TrainOnLine ( const ImageData *  trainingdata,
bool  batch 
)

Definition at line 737 of file lstmtrainer.cpp.

738  {
739  NetworkIO fwd_outputs, targets;
740  Trainability trainable =
741  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
743  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
744  return trainable; // Sample was unusable.
745  }
746  bool debug = debug_interval_ > 0 &&
748  // Run backprop on the output.
749  NetworkIO bp_deltas;
750  if (network_->IsTraining() &&
751  (trainable != PERFECT ||
754  network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
756  training_iteration_ + 1);
757  }
758 #ifndef GRAPHICS_DISABLED
759  if (debug_interval_ == 1 && debug_win_ != nullptr) {
761  }
762 #endif // GRAPHICS_DISABLED
763  // Roll the memory of past means.
765  return trainable;
766 }

◆ TrainOnLine() [2/2]

const ImageData* tesseract::LSTMTrainer::TrainOnLine ( LSTMTrainer *  samples_trainer,
bool  batch 
)
inline

Definition at line 245 of file lstmtrainer.h.

245  {
246  int sample_index = sample_iteration();
247  const ImageData* image =
248  samples_trainer->training_data_.GetPageBySerial(sample_index);
249  if (image != nullptr) {
250  Trainability trainable = TrainOnLine(image, batch);
251  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
252  return nullptr; // Sample was unusable.
253  }
254  } else {
256  }
257  return image;
258  }

◆ TransitionTrainingStage()

bool tesseract::LSTMTrainer::TransitionTrainingStage ( float  error_threshold)

Definition at line 393 of file lstmtrainer.cpp.

393  {
394  if (best_error_rate_ < error_threshold &&
396  ++training_stage_;
397  return true;
398  }
399  return false;
400 }

◆ TryLoadingCheckpoint()

bool tesseract::LSTMTrainer::TryLoadingCheckpoint ( const char *  filename,
const char *  old_traineddata 
)

Definition at line 103 of file lstmtrainer.cpp.

104  {
105  GenericVector<char> data;
106  if (!LoadDataFromFile(filename, &data)) return false;
107  tprintf("Loaded file %s, unpacking...\n", filename);
108  if (!ReadTrainingDump(data, this)) return false;
109  StaticShape shape = network_->OutputShape(network_->InputShape());
110  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
112  filename == old_traineddata) {
113  return true; // Normal checkpoint load complete.
114  }
115  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
116  recoder_.code_range());
117  if (old_traineddata == nullptr || *old_traineddata == '\0') {
118  tprintf("Must supply the old traineddata for code conversion!\n");
119  return false;
120  }
121  TessdataManager old_mgr;
122  ASSERT_HOST(old_mgr.Init(old_traineddata));
123  TFile fp;
124  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
125  UNICHARSET old_chset;
126  if (!old_chset.load_from_file(&fp, false)) return false;
127  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
128  UnicharCompress old_recoder;
129  if (!old_recoder.DeSerialize(&fp)) return false;
130  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
131  // Set the null_char_ to the new value.
132  int old_null_char = null_char_;
133  SetNullChar();
134  // Map the softmax(s) in the network.
135  network_->RemapOutputs(old_recoder.code_range(), code_map);
136  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
137  return true;
138 }

◆ UpdateErrorBuffer()

void tesseract::LSTMTrainer::UpdateErrorBuffer ( double  new_error,
ErrorTypes  type 
)
protected

Definition at line 1228 of file lstmtrainer.cpp.

1228  {
1230  error_buffers_[type][index] = new_error;
1231  // Compute the mean error.
1232  int mean_count = std::min(training_iteration_ + 1, error_buffers_[type].size());
1233  double buffer_sum = 0.0;
1234  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
1235  double mean = buffer_sum / mean_count;
1236  // Trim precision to 1/1000 of 1%.
1237  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1238 }

◆ UpdateErrorGraph()

STRING tesseract::LSTMTrainer::UpdateErrorGraph ( int  iteration,
double  error_rate,
const GenericVector< char > &  model_data,
TestCallback  tester 
)
protected

Definition at line 1260 of file lstmtrainer.cpp.

1262  {
1263  if (error_rate > best_error_rate_
1264  && iteration < best_iteration_ + kErrorGraphInterval) {
1265  // Too soon to record a new point.
1266  if (tester != nullptr && !worst_model_data_.empty()) {
1269  return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
1270  } else {
1271  return "";
1272  }
1273  }
1274  STRING result;
1275  // NOTE: there are 2 asymmetries here:
1276  // 1. We are computing the global minimum, but the local maximum in between.
1277  // 2. If the tester returns an empty string, indicating that it is busy,
1278  // call it repeatedly on new local maxima to test the previous min, but
1279  // not the other way around, as there is little point testing the maxima
1280  // between very frequent minima.
1281  if (error_rate < best_error_rate_) {
1282  // This is a new (global) minimum.
1283  if (tester != nullptr && !worst_model_data_.empty()) {
1286  result = tester(worst_iteration_, worst_error_rates_, mgr_,
1289  best_model_data_ = model_data;
1290  }
1291  best_error_rate_ = error_rate;
1292  memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1293  best_iteration_ = iteration;
1294  best_error_history_.push_back(error_rate);
1295  best_error_iterations_.push_back(iteration);
1296  // Compute 2% decay time.
1297  double two_percent_more = error_rate + 2.0;
1298  int i;
1299  for (i = best_error_history_.size() - 1;
1300  i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1301  }
1302  int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1303  improvement_steps_ = iteration - old_iteration;
1304  tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1305  improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1306  old_iteration);
1307  } else if (error_rate > best_error_rate_) {
1308  // This is a new (local) maximum.
1309  if (tester != nullptr) {
1310  if (!best_model_data_.empty()) {
1313  result = tester(best_iteration_, best_error_rates_, mgr_,
1315  } else if (!worst_model_data_.empty()) {
1316  // Allow for multiple data points with "worst" error rate.
1319  result = tester(worst_iteration_, worst_error_rates_, mgr_,
1321  }
1322  if (result.length() > 0)
1324  worst_model_data_ = model_data;
1325  }
1326  }
1327  worst_error_rate_ = error_rate;
1328  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1329  worst_iteration_ = iteration;
1330  return result;
1331 }

◆ UpdateSubtrainer()

SubTrainerResult tesseract::LSTMTrainer::UpdateSubtrainer ( STRING *  log_msg)

Definition at line 519 of file lstmtrainer.cpp.

519  {
520  double training_error = CharError();
521  double sub_error = sub_trainer_->CharError();
522  double sub_margin = (training_error - sub_error) / sub_error;
523  if (sub_margin >= kSubTrainerMarginFraction) {
524  log_msg->add_str_double(" sub_trainer=", sub_error);
525  log_msg->add_str_double(" margin=", 100.0 * sub_margin);
526  *log_msg += "\n";
527  // Catch up to current iteration.
528  int end_iteration = training_iteration();
529  while (sub_trainer_->training_iteration() < end_iteration &&
530  sub_margin >= kSubTrainerMarginFraction) {
531  int target_iteration =
533  while (sub_trainer_->training_iteration() < target_iteration) {
534  sub_trainer_->TrainOnLine(this, false);
535  }
536  STRING batch_log = "Sub:";
537  sub_trainer_->PrepareLogMsg(&batch_log);
538  batch_log += "\n";
539  tprintf("UpdateSubtrainer:%s", batch_log.c_str());
540  *log_msg += batch_log;
541  sub_error = sub_trainer_->CharError();
542  sub_margin = (training_error - sub_error) / sub_error;
543  }
544  if (sub_error < best_error_rate_ &&
545  sub_margin >= kSubTrainerMarginFraction) {
546  // The sub_trainer_ has won the race to a new best. Switch to it.
547  GenericVector<char> updated_trainer;
548  SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
549  ReadTrainingDump(updated_trainer, this);
550  log_msg->add_str_int(" Sub trainer wins at iteration ",
552  *log_msg += "\n";
553  return STR_REPLACED;
554  }
555  return STR_UPDATED;
556  }
557  return STR_NONE;
558 }

Member Data Documentation

◆ align_win_

ScrollView* tesseract::LSTMTrainer::align_win_
protected

Definition at line 383 of file lstmtrainer.h.

◆ best_error_history_

GenericVector<double> tesseract::LSTMTrainer::best_error_history_
protected

Definition at line 436 of file lstmtrainer.h.

◆ best_error_iterations_

GenericVector<int> tesseract::LSTMTrainer::best_error_iterations_
protected

Definition at line 437 of file lstmtrainer.h.

◆ best_error_rate_

double tesseract::LSTMTrainer::best_error_rate_
protected

Definition at line 409 of file lstmtrainer.h.

◆ best_error_rates_

double tesseract::LSTMTrainer::best_error_rates_[ET_COUNT]
protected

Definition at line 411 of file lstmtrainer.h.

◆ best_iteration_

int tesseract::LSTMTrainer::best_iteration_
protected

Definition at line 413 of file lstmtrainer.h.

◆ best_model_data_

GenericVector<char> tesseract::LSTMTrainer::best_model_data_
protected

Definition at line 423 of file lstmtrainer.h.

◆ best_model_name_

STRING tesseract::LSTMTrainer::best_model_name_
protected

Definition at line 402 of file lstmtrainer.h.

◆ best_trainer_

GenericVector<char> tesseract::LSTMTrainer::best_trainer_
protected

Definition at line 426 of file lstmtrainer.h.

◆ checkpoint_iteration_

int tesseract::LSTMTrainer::checkpoint_iteration_
protected

Definition at line 393 of file lstmtrainer.h.

◆ checkpoint_name_

STRING tesseract::LSTMTrainer::checkpoint_name_
protected

Definition at line 397 of file lstmtrainer.h.

◆ ctc_win_

ScrollView* tesseract::LSTMTrainer::ctc_win_
protected

Definition at line 387 of file lstmtrainer.h.

◆ debug_interval_

int tesseract::LSTMTrainer::debug_interval_
protected

Definition at line 391 of file lstmtrainer.h.

◆ error_buffers_

GenericVector<double> tesseract::LSTMTrainer::error_buffers_[ET_COUNT]
protected

Definition at line 458 of file lstmtrainer.h.

◆ error_rate_of_last_saved_best_

float tesseract::LSTMTrainer::error_rate_of_last_saved_best_
protected

Definition at line 431 of file lstmtrainer.h.

◆ error_rates_

double tesseract::LSTMTrainer::error_rates_[ET_COUNT]
protected

Definition at line 460 of file lstmtrainer.h.

◆ improvement_steps_

int32_t tesseract::LSTMTrainer::improvement_steps_
protected

Definition at line 439 of file lstmtrainer.h.

◆ kRollingBufferSize_

const int tesseract::LSTMTrainer::kRollingBufferSize_ = 1000
staticprotected

Definition at line 457 of file lstmtrainer.h.

◆ last_perfect_training_iteration_

int tesseract::LSTMTrainer::last_perfect_training_iteration_
protected

Definition at line 454 of file lstmtrainer.h.

◆ learning_iteration_

int tesseract::LSTMTrainer::learning_iteration_
protected

Definition at line 443 of file lstmtrainer.h.

◆ mgr_

TessdataManager tesseract::LSTMTrainer::mgr_
protected

Definition at line 462 of file lstmtrainer.h.

◆ model_base_

STRING tesseract::LSTMTrainer::model_base_
protected

Definition at line 395 of file lstmtrainer.h.

◆ num_training_stages_

int tesseract::LSTMTrainer::num_training_stages_
protected

Definition at line 404 of file lstmtrainer.h.

◆ perfect_delay_

int tesseract::LSTMTrainer::perfect_delay_
protected

Definition at line 451 of file lstmtrainer.h.

◆ prev_sample_iteration_

int tesseract::LSTMTrainer::prev_sample_iteration_
protected

Definition at line 445 of file lstmtrainer.h.

◆ randomly_rotate_

bool tesseract::LSTMTrainer::randomly_rotate_
protected

Definition at line 399 of file lstmtrainer.h.

◆ recon_win_

ScrollView* tesseract::LSTMTrainer::recon_win_
protected

Definition at line 389 of file lstmtrainer.h.

◆ stall_iteration_

int tesseract::LSTMTrainer::stall_iteration_
protected

Definition at line 421 of file lstmtrainer.h.

◆ sub_trainer_

LSTMTrainer* tesseract::LSTMTrainer::sub_trainer_
protected

Definition at line 429 of file lstmtrainer.h.

◆ target_win_

ScrollView* tesseract::LSTMTrainer::target_win_
protected

Definition at line 385 of file lstmtrainer.h.

◆ training_data_

DocumentCache tesseract::LSTMTrainer::training_data_
protected

Definition at line 400 of file lstmtrainer.h.

◆ training_stage_

int tesseract::LSTMTrainer::training_stage_
protected

Definition at line 433 of file lstmtrainer.h.

◆ worst_error_rate_

double tesseract::LSTMTrainer::worst_error_rate_
protected

Definition at line 415 of file lstmtrainer.h.

◆ worst_error_rates_

double tesseract::LSTMTrainer::worst_error_rates_[ET_COUNT]
protected

Definition at line 417 of file lstmtrainer.h.

◆ worst_iteration_

int tesseract::LSTMTrainer::worst_iteration_
protected

Definition at line 419 of file lstmtrainer.h.

◆ worst_model_data_

GenericVector<char> tesseract::LSTMTrainer::worst_model_data_
protected

Definition at line 424 of file lstmtrainer.h.


The documentation for this class was generated from the following files:
- lstmtrainer.h
- lstmtrainer.cpp
UNICHARSET::load_from_file
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:378
tesseract::ET_COUNT
Definition: lstmtrainer.h:43
string
std::string string
Definition: equationdetect_test.cc:21
tesseract::LSTMTrainer::DebugLSTMTraining
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
Definition: lstmtrainer.cpp:1005
tesseract::TS_ENABLED
Definition: network.h:95
tesseract::RecodeBeamSearch::kMinCertainty
static constexpr float kMinCertainty
Definition: recodebeam.h:252
tesseract::kSubTrainerMarginFraction
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:51
tesseract::LSTMRecognizer::learning_rate_
float learning_rate_
Definition: lstmrecognizer.h:283
tesseract::STR_NONE
Definition: lstmtrainer.h:64
tesseract::LSTMTrainer::randomly_rotate_
bool randomly_rotate_
Definition: lstmtrainer.h:399
tesseract::STR_REPLACED
Definition: lstmtrainer.h:66
tesseract::kMinDivergenceRate
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:46
tesseract::LSTMRecognizer::dict_
Dict * dict_
Definition: lstmrecognizer.h:292
tesseract::LT_CTC
Definition: static_shape.h:31
tesseract::Network::Backward
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
tesseract::LSTMRecognizer::DebugActivationPath
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
Definition: lstmrecognizer.cpp:392
tesseract::LSTMTrainer::learning_iteration_
int learning_iteration_
Definition: lstmtrainer.h:443
tesseract::LSTMTrainer::PrepareForBackward
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:770
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition: strngs.cpp:370
tesseract::LSTMTrainer::UpdateSubtrainer
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:519
tesseract::LSTMTrainer::SaveTrainingDump
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
Definition: lstmtrainer.cpp:874
SVET_CLICK
Definition: scrollview.h:47
UNICHARSET::encode_string
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
Definition: unicharset.cpp:258
tesseract::TessdataManager::SetVersionString
void SetVersionString(const std::string &v_str)
Definition: tessdatamanager.cpp:239
tesseract::Network::SetEnableTraining
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
tesseract::LSTMTrainer::SaveRecognitionDump
void SaveRecognitionDump(GenericVector< char > *data) const
Definition: lstmtrainer.cpp:904
tesseract::LSTMTrainer::mgr_
TessdataManager mgr_
Definition: lstmtrainer.h:462
tesseract::LSTMRecognizer::ScaleLayerLearningRate
void ScaleLayerLearningRate(const STRING &id, double factor)
Definition: lstmrecognizer.h:116
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::LSTMTrainer::ReadSizedTrainingDump
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:282
tesseract::LSTMTrainer::best_error_rates_
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:411
tesseract::LSTMTrainer::num_training_stages_
int num_training_stages_
Definition: lstmtrainer.h:404
tesseract::LSTMTrainer::MapRecoder
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
Definition: lstmtrainer.cpp:933
tesseract::LoadDataFromFile
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
Definition: genericvector.h:341
tesseract::kTargetXScale
const int kTargetXScale
Definition: lstmtrainer.cpp:71
tesseract::LSTMTrainer::prev_sample_iteration_
int prev_sample_iteration_
Definition: lstmtrainer.h:445
tesseract::LSTMRecognizer::training_iteration
int training_iteration() const
Definition: lstmrecognizer.h:60
tesseract::LSTMTrainer::worst_error_rate_
double worst_error_rate_
Definition: lstmtrainer.h:415
tesseract::TessdataManager::VersionString
std::string VersionString() const
Definition: tessdatamanager.cpp:233
tesseract::LSTMTrainer::checkpoint_name_
STRING checkpoint_name_
Definition: lstmtrainer.h:397
tesseract::DocumentCache::LoadDocuments
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:566
tesseract::LSTMTrainer::best_error_rate_
double best_error_rate_
Definition: lstmtrainer.h:409
tesseract::LSTMTrainer::DumpFilename
STRING DumpFilename() const
Definition: lstmtrainer.cpp:914
tesseract::countof
constexpr size_t countof(T const (&)[N]) noexcept
Definition: serialis.h:41
tesseract::LSTMRecognizer::randomizer_
TRand randomizer_
Definition: lstmrecognizer.h:289
tesseract::LSTMTrainer::debug_interval_
int debug_interval_
Definition: lstmtrainer.h:391
tesseract::LSTMRecognizer::learning_rate
double learning_rate() const
Definition: lstmrecognizer.h:62
tesseract::LSTMTrainer::perfect_delay_
int perfect_delay_
Definition: lstmtrainer.h:451
tesseract::kNumAdjustmentIterations
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:55
tesseract::Network::RemapOutputs
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:186
STRING
Definition: strngs.h:45
tesseract::UnicharCompress::code_range
int code_range() const
Definition: unicharcompress.h:161
tesseract::LSTMTrainer::error_buffers_
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:458
GenericVector::Serialize
bool Serialize(FILE *fp) const
Definition: genericvector.h:929
tesseract::SaveDataToFile
bool SaveDataToFile(const GenericVector< char > &data, const char *filename)
Definition: genericvector.h:362
tesseract::LSTMRecognizer::sample_iteration_
int32_t sample_iteration_
Definition: lstmrecognizer.h:278
tesseract::LSTMTrainer::EncodeString
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:232
IntCastRounded
int IntCastRounded(double x)
Definition: helpers.h:173
tesseract::LSTMRecognizer::network_str_
STRING network_str_
Definition: lstmrecognizer.h:271
tesseract::LSTMRecognizer::EnumerateLayers
GenericVector< STRING > EnumerateLayers() const
Definition: lstmrecognizer.h:79
tesseract::LSTMRecognizer::LabelsFromOutputs
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:462
tesseract::kMinStartedErrorRate
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:61
tesseract::LSTMRecognizer::DisplayForward
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
Definition: lstmrecognizer.cpp:349
tesseract::LSTMRecognizer::GetUnicharset
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:132
tesseract::LSTMTrainer::model_base_
STRING model_base_
Definition: lstmtrainer.h:395
tesseract::Network::DebugWeights
virtual void DebugWeights()=0
tesseract::LSTMTrainer::TrainOnLine
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:245
tesseract::LSTMTrainer::best_error_history_
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:436
tesseract::Network::TestFlag
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
tesseract::ET_WORD_RECERR
Definition: lstmtrainer.h:40
tesseract::Network::IsTraining
bool IsTraining() const
Definition: network.h:115
tesseract::PERFECT
Definition: lstmtrainer.h:49
tesseract::LSTMRecognizer::GetLayerLearningRate
float GetLayerLearningRate(const STRING &id) const
Definition: lstmrecognizer.h:94
tesseract::LSTMTrainer::best_trainer_
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:426
GenericVector::reverse
void reverse()
Definition: genericvector.h:215
tesseract::LSTMTrainer::StartSubtrainer
void StartSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:489
tesseract::TessdataManager::OverwriteEntry
void OverwriteEntry(TessdataType type, const char *data, int size)
Definition: tessdatamanager.cpp:145
tesseract::LSTMRecognizer::LoadCharsets
bool LoadCharsets(const TessdataManager *mgr)
Definition: lstmrecognizer.cpp:133
tesseract::Network::OutputShape
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
tesseract::FULL
Definition: lstmtrainer.h:59
tesseract::LSTMTrainer::error_rate_of_last_saved_best_
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:431
tesseract::LSTMTrainer::NewSingleError
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:140
tesseract::LSTMRecognizer::DeSerialize
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmrecognizer.cpp:108
UNICHAR_BROKEN
Definition: unicharset.h:36
tesseract::LSTMTrainer::RollErrorBuffers
void RollErrorBuffers()
Definition: lstmtrainer.cpp:1241
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::LSTMRecognizer::RecognizeLine
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
Definition: lstmrecognizer.cpp:187
tesseract::LSTMRecognizer::scratch_space_
NetworkScratch scratch_space_
Definition: lstmrecognizer.h:290
tesseract::LSTMTrainer::align_win_
ScrollView * align_win_
Definition: lstmtrainer.h:383
tesseract::LossType
LossType
Definition: static_shape.h:29
tesseract::LSTMRecognizer::training_flags_
int32_t training_flags_
Definition: lstmrecognizer.h:274
tesseract::LSTMRecognizer::null_char
int null_char() const
Definition: lstmrecognizer.h:145
tesseract::LSTMRecognizer::SetRandomSeed
void SetRandomSeed()
Definition: lstmrecognizer.h:217
tesseract::LSTMTrainer::CharError
double CharError() const
Definition: lstmtrainer.h:125
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
tesseract::LSTMRecognizer::momentum_
float momentum_
Definition: lstmrecognizer.h:284
tesseract::LSTMRecognizer::debug_win_
ScrollView * debug_win_
Definition: lstmrecognizer.h:298
GenericVector::DeSerialize
bool DeSerialize(bool swap, FILE *fp)
Definition: genericvector.h:954
tesseract::LSTMTrainer::best_model_data_
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:423
tesseract::LSTMTrainer::UpdateErrorBuffer
void UpdateErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:1228
tesseract::LSTMTrainer::SetNullChar
void SetNullChar()
Definition: lstmtrainer.cpp:981
tesseract::LSTMTrainer::learning_iteration
int learning_iteration() const
Definition: lstmtrainer.h:135
tesseract::LSTMTrainer::ComputeRMSError
double ComputeRMSError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1130
UNICHARSET::unichar_to_id
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
Definition: unicharset.cpp:209
tesseract::LSTMRecognizer::OutputLossType
LossType OutputLossType() const
Definition: lstmrecognizer.h:63
tesseract::TESSDATA_LSTM_RECODER
Definition: tessdatamanager.h:79
UNICHAR_SPACE
Definition: unicharset.h:34
tesseract::LSTMRecognizer::ScaleLearningRate
void ScaleLearningRate(double factor)
Definition: lstmrecognizer.h:105
tesseract::Network::InputShape
virtual StaticShape InputShape() const
Definition: network.h:127
tesseract::LSTMTrainer::ComputeErrorRates
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
Definition: lstmtrainer.cpp:1110
tesseract::STR_UPDATED
Definition: lstmtrainer.h:65
tesseract::kNumPagesPerBatch
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:59
GenericVector::empty
bool empty() const
Definition: genericvector.h:86
UNICHARSET
Definition: unicharset.h:145
tesseract::LSTMTrainer::improvement_steps_
int32_t improvement_steps_
Definition: lstmtrainer.h:439
tesseract::LSTMRecognizer::GetLayer
Network * GetLayer(const STRING &id) const
Definition: lstmrecognizer.h:87
tesseract::LSTMTrainer::ctc_win_
ScrollView * ctc_win_
Definition: lstmtrainer.h:387
tesseract::kStageTransitionThreshold
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:63
tesseract::LSTMRecognizer::Serialize
bool Serialize(const TessdataManager *mgr, TFile *fp) const
Definition: lstmrecognizer.cpp:89
tesseract::LSTMRecognizer::recoder_
UnicharCompress recoder_
Definition: lstmrecognizer.h:268
UNICHARSET::CleanupString
static std::string CleanupString(const char *utf8_str)
Definition: unicharset.h:246
tesseract::Trainability
Trainability
Definition: lstmtrainer.h:47
tesseract::LSTMRecognizer::training_iteration_
int32_t training_iteration_
Definition: lstmrecognizer.h:276
tesseract::kTargetYScale
const int kTargetYScale
Definition: lstmtrainer.cpp:72
tesseract::TS_RE_ENABLE
Definition: network.h:99
tesseract::LSTMTrainer::worst_iteration_
int worst_iteration_
Definition: lstmtrainer.h:419
tesseract::LSTMTrainer::ComputeWordError
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
Definition: lstmtrainer.cpp:1195
tesseract::LSTMTrainer::target_win_
ScrollView * target_win_
Definition: lstmtrainer.h:385
tesseract::LSTMTrainer::stall_iteration_
int stall_iteration_
Definition: lstmtrainer.h:421
tesseract::LSTMTrainer::ReadTrainingDump
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:277
tesseract::LSTMTrainer::training_stage_
int training_stage_
Definition: lstmtrainer.h:433
tesseract::SubTrainerResult
SubTrainerResult
Definition: lstmtrainer.h:63
tesseract::kLearningRateDecay
const double kLearningRateDecay
Definition: lstmtrainer.cpp:53
tesseract::LSTMRecognizer::SimpleTextOutput
bool SimpleTextOutput() const
Definition: lstmrecognizer.h:69
tesseract::LSTMTrainer::DisplayTargets
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
Definition: lstmtrainer.cpp:1042
tesseract::LSTMRecognizer::null_char_
int32_t null_char_
Definition: lstmrecognizer.h:281
UNICHARSET::has_special_codes
bool has_special_codes() const
Definition: unicharset.h:712
tesseract::Network::NumOutputs
int NumOutputs() const
Definition: network.h:123
tesseract::ET_CHAR_ERROR
Definition: lstmtrainer.h:41
tesseract::TRand::SignedRand
double SignedRand(double range)
Definition: helpers.h:85
tesseract::LSTMTrainer::best_iteration_
int best_iteration_
Definition: lstmtrainer.h:413
GenericVector< int >
tesseract::LSTMTrainer::sub_trainer_
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:429
tesseract::LSTMTrainer::DeSerialize
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmtrainer.cpp:440
tesseract::kImprovementFraction
const double kImprovementFraction
Definition: lstmtrainer.cpp:67
tesseract::LT_SOFTMAX
Definition: static_shape.h:32
ScrollView::AwaitEvent
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:443
tesseract::HI_PRECISION_ERR
Definition: lstmtrainer.h:51
tesseract::ET_RMS
Definition: lstmtrainer.h:38
STRING::length
int32_t length() const
Definition: strngs.cpp:187
tesseract::kBestCheckpointFraction
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:69
tesseract::LSTMTrainer::EmptyConstructor
void EmptyConstructor()
Definition: lstmtrainer.cpp:990
tesseract::Network::ClearWindow
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
tesseract::Network::Update
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:230
GenericVector::truncate
void truncate(int size)
Definition: genericvector.h:132
tesseract::LSTMTrainer::ComputeCTCTargets
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:1099
tesseract::LSTMTrainer::recon_win_
ScrollView * recon_win_
Definition: lstmtrainer.h:389
tesseract::LSTMTrainer::InitCharSet
void InitCharSet()
Definition: lstmtrainer.cpp:968
tesseract::TessdataManager::Init
bool Init(const char *data_file_name)
Definition: tessdatamanager.cpp:97
tesseract::LSTMRecognizer::IsRecoding
bool IsRecoding() const
Definition: lstmrecognizer.h:72
tesseract::CTC::NormalizeProbs
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
STRING::add_str_double
void add_str_double(const char *str, double number)
Definition: strngs.cpp:380
tesseract::LSTMTrainer::ReduceLearningRates
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
Definition: lstmtrainer.cpp:562
tesseract::TS_TEMP_DISABLE
Definition: network.h:97
tesseract::kMinStallIterations
const int kMinStallIterations
Definition: lstmtrainer.cpp:48
GenericVector::init_to_size
void init_to_size(int size, const T &t)
Definition: genericvector.h:706
tesseract::LSTMTrainer::last_perfect_training_iteration_
int last_perfect_training_iteration_
Definition: lstmtrainer.h:454
tesseract::NetworkBuilder::InitNetwork
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Definition: networkbuilder.cpp:45
tesseract::LSTMRecognizer::NumOutputs
int NumOutputs() const
Definition: lstmrecognizer.h:59
tesseract::LSTMTrainer::CurrentTrainingStage
int CurrentTrainingStage() const
Definition: lstmtrainer.h:197
tesseract::LSTMTrainer::ReduceLayerLearningRates
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
Definition: lstmtrainer.cpp:581
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::LSTMRecognizer::network_
Network * network_
Definition: lstmrecognizer.h:261
tesseract::LSTMTrainer::best_error_iterations_
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:437
tesstrain_utils.type
type
Definition: tesstrain_utils.py:141
tesseract::kErrorGraphInterval
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:57
tesseract::LSTMTrainer::UpdateErrorGraph
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
Definition: lstmtrainer.cpp:1260
tesseract::Network::spec
virtual STRING spec() const
Definition: network.h:141
tesseract::LSTMTrainer::worst_model_data_
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:424
tesseract::ET_SKIP_RATIO
Definition: lstmtrainer.h:42
tesseract::LIGHT
Definition: lstmtrainer.h:57
tesseract::NF_LAYER_SPECIFIC_LR
Definition: network.h:87
tesseract::LSTMTrainer::worst_error_rates_
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:417
tesseract::kHighConfidence
const double kHighConfidence
Definition: lstmtrainer.cpp:65
tesseract::LSTMTrainer::kRollingBufferSize_
static const int kRollingBufferSize_
Definition: lstmtrainer.h:457
tesseract::LSTMTrainer::error_rates_
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:460
tesseract::DocumentCache::Clear
void Clear()
Definition: imagedata.h:326
tesseract::LSTMTrainer::training_data_
DocumentCache training_data_
Definition: lstmtrainer.h:400
tesseract::LSTMTrainer::TransitionTrainingStage
bool TransitionTrainingStage(float error_threshold)
Definition: lstmtrainer.cpp:393
tesseract::UnicharCompress::EncodeUnichar
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
Definition: unicharcompress.cpp:283
tesseract::LSTMTrainer::PrepareLogMsg
void PrepareLogMsg(STRING *log_msg) const
Definition: lstmtrainer.cpp:372
tesseract::TF_COMPRESS_UNICHARSET
Definition: lstmrecognizer.h:48
tesseract::UNENCODABLE
Definition: lstmtrainer.h:50
tesseract::TESSDATA_LSTM
Definition: tessdatamanager.h:74
tesseract::LSTMTrainer::ComputeTextTargets
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
Definition: lstmtrainer.cpp:1079
GenericVector::size
int size() const
Definition: genericvector.h:71
ScrollView::GREEN_YELLOW
Definition: scrollview.h:149
tesseract::LSTMRecognizer::adam_beta_
float adam_beta_
Definition: lstmrecognizer.h:286
tesseract::LSTMTrainer::LSTMTrainer
LSTMTrainer()
Definition: lstmtrainer.cpp:74
tesseract::LSTMTrainer::InitIterations
void InitIterations()
Definition: lstmtrainer.cpp:190
tesseract::LSTMTrainer::checkpoint_iteration_
int checkpoint_iteration_
Definition: lstmtrainer.h:393
tesseract::NO_BEST_TRAINER
Definition: lstmtrainer.h:58
tesseract::TRAINABLE
Definition: lstmtrainer.h:48
tesseract::TESSDATA_LSTM_UNICHARSET
Definition: tessdatamanager.h:78
tesseract::CTC::ComputeCTCTargets
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:54
search
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:202
tesseract::ET_DELTA
Definition: lstmtrainer.h:39
tesseract::TessdataManager::SaveFile
bool SaveFile(const STRING &filename, FileWriter writer) const
Definition: tessdatamanager.cpp:153
UNICHARSET::size
int size() const
Definition: unicharset.h:341
STRING::split
void split(char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:275
tesseract::LSTMTrainer::ComputeWinnerError
double ComputeWinnerError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1149
tesseract::LSTMTrainer::LogIterations
void LogIterations(const char *intro_str, STRING *log_msg) const
Definition: lstmtrainer.cpp:384
tesseract::LSTMTrainer::ComputeCharError
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Definition: lstmtrainer.cpp:1167
tesseract::LSTMRecognizer::sample_iteration
int sample_iteration() const
Definition: lstmrecognizer.h:61
tesseract::NOT_BOXED
Definition: lstmtrainer.h:52
tesseract::LSTMRecognizer::DecodeLabels
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: lstmrecognizer.cpp:334