tesseract  5.0.0-alpha-619-ge9db
lstmtrainer.cpp
Go to the documentation of this file.
1 // File: lstmtrainer.cpp
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2
19 
20 // Include automatically generated configuration file if running autoconf.
21 #ifdef HAVE_CONFIG_H
22 #include "config_auto.h"
23 #endif
24 
25 #include "lstmtrainer.h"
26 #include <string>
27 
28 #include "allheaders.h"
29 #include "boxread.h"
30 #include "ctc.h"
31 #include "imagedata.h"
32 #include "input.h"
33 #include "networkbuilder.h"
34 #include "ratngs.h"
35 #include "recodebeam.h"
36 #ifdef INCLUDE_TENSORFLOW
37 #include "tfnetwork.h"
38 #endif
39 #include "tprintf.h"
40 
41 #include "callcpp.h"
42 
43 namespace tesseract {
44 
// Smallest rise in the actual error rate that counts as divergence.
constexpr double kMinDivergenceRate = 50.0;
// A stall is only acted upon once this many iterations have passed since the
// last best.
constexpr int kMinStallIterations = 10000;
// The sub_trainer_ must be ahead by this fraction of the current char error
// rate before it is declared a success and we switch to it.
constexpr double kSubTrainerMarginFraction = 3.0 / 128;
// Multiplier used to reduce the learning rate on divergence.
constexpr double kLearningRateDecay = M_SQRT1_2;
// Number of iterations used for a learning rate adjustment.
constexpr int kNumAdjustmentIterations = 100;
// Interval at which data is added to the error_graph_.
constexpr int kErrorGraphInterval = 1000;
// Count of training images to train between calls to MaintainCheckpoints.
constexpr int kNumPagesPerBatch = 100;
// Minimum percent error rate at which the start-up phase is considered over.
constexpr int kMinStartedErrorRate = 75;
// The transition to stage 1 happens at this error rate.
constexpr double kStageTransitionThreshold = 10.0;
// Above this confidence the truth is more likely to be wrong than the
// recognizer. Equal to 15/16.
constexpr double kHighConfidence = 0.9375;
// Fraction of the weight sign-changing total that constitutes a definite
// improvement.
constexpr double kImprovementFraction = 15.0 / 16.0;
// Fraction of the last written best that makes writing another worthwhile.
constexpr double kBestCheckpointFraction = 31.0 / 32.0;
// Scale factors (x and y) for the display of target activations of CTC.
constexpr int kTargetXScale = 5;
constexpr int kTargetYScale = 100;
73 
75  : randomly_rotate_(false),
76  training_data_(0),
77  sub_trainer_(nullptr) {
79  debug_interval_ = 0;
80 }
81 
82 LSTMTrainer::LSTMTrainer(const char* model_base, const char* checkpoint_name,
83  int debug_interval, int64_t max_memory)
84  : randomly_rotate_(false),
85  training_data_(max_memory),
86  sub_trainer_(nullptr) {
88  debug_interval_ = debug_interval;
89  model_base_ = model_base;
90  checkpoint_name_ = checkpoint_name;
91 }
92 
94  delete align_win_;
95  delete target_win_;
96  delete ctc_win_;
97  delete recon_win_;
98  delete sub_trainer_;
99 }
100 
101 // Tries to deserialize a trainer from the given file and silently returns
102 // false in case of failure.
103 bool LSTMTrainer::TryLoadingCheckpoint(const char* filename,
104  const char* old_traineddata) {
105  GenericVector<char> data;
106  if (!LoadDataFromFile(filename, &data)) return false;
107  tprintf("Loaded file %s, unpacking...\n", filename);
108  if (!ReadTrainingDump(data, this)) return false;
110  if (((old_traineddata == nullptr || *old_traineddata == '\0') &&
112  filename == old_traineddata) {
113  return true; // Normal checkpoint load complete.
114  }
115  tprintf("Code range changed from %d to %d!\n", network_->NumOutputs(),
116  recoder_.code_range());
117  if (old_traineddata == nullptr || *old_traineddata == '\0') {
118  tprintf("Must supply the old traineddata for code conversion!\n");
119  return false;
120  }
121  TessdataManager old_mgr;
122  ASSERT_HOST(old_mgr.Init(old_traineddata));
123  TFile fp;
124  if (!old_mgr.GetComponent(TESSDATA_LSTM_UNICHARSET, &fp)) return false;
125  UNICHARSET old_chset;
126  if (!old_chset.load_from_file(&fp, false)) return false;
127  if (!old_mgr.GetComponent(TESSDATA_LSTM_RECODER, &fp)) return false;
128  UnicharCompress old_recoder;
129  if (!old_recoder.DeSerialize(&fp)) return false;
130  std::vector<int> code_map = MapRecoder(old_chset, old_recoder);
131  // Set the null_char_ to the new value.
132  int old_null_char = null_char_;
133  SetNullChar();
134  // Map the softmax(s) in the network.
135  network_->RemapOutputs(old_recoder.code_range(), code_map);
136  tprintf("Previous null char=%d mapped to %d\n", old_null_char, null_char_);
137  return true;
138 }
139 
140 // Initializes the trainer with a network_spec in the network description
141 // net_flags control network behavior according to the NetworkFlags enum.
142 // There isn't really much difference between them - only where the effects
143 // are implemented.
144 // For other args see NetworkBuilder::InitNetwork.
145 // Note: Be sure to call InitCharSet before InitNetwork!
146 bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
147  int net_flags, float weight_range,
148  float learning_rate, float momentum,
149  float adam_beta) {
150  mgr_.SetVersionString(mgr_.VersionString() + ":" + network_spec.c_str());
151  adam_beta_ = adam_beta;
153  momentum_ = momentum;
154  SetNullChar();
155  if (!NetworkBuilder::InitNetwork(recoder_.code_range(), network_spec,
156  append_index, net_flags, weight_range,
157  &randomizer_, &network_)) {
158  return false;
159  }
160  network_str_ += network_spec;
161  tprintf("Built network:%s from request %s\n",
162  network_->spec().c_str(), network_spec.c_str());
163  tprintf(
164  "Training parameters:\n Debug interval = %d,"
165  " weights = %g, learning rate = %g, momentum=%g\n",
166  debug_interval_, weight_range, learning_rate_, momentum_);
167  tprintf("null char=%d\n", null_char_);
168  return true;
169 }
170 
// Initializes a trainer from a serialized TFNetworkModel proto.
// Any network currently held in network_ is destroyed first.
// tf_proto: the serialized proto string handed to TFNetwork::InitFromProtoStr.
// Returns the global step of TensorFlow graph or 0 if failed. On failure
// network_ is left as nullptr and the partially built TFNetwork is released.
#ifdef INCLUDE_TENSORFLOW
int LSTMTrainer::InitTensorFlowNetwork(const std::string& tf_proto) {
  // Clear the pointer straight after deleting, so that the early return
  // below cannot leave network_ pointing at a destroyed object.
  delete network_;
  network_ = nullptr;
  TFNetwork* tf_net = new TFNetwork("TensorFlow");
  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
  if (training_iteration_ == 0) {
    tprintf("InitFromProtoStr failed!!\n");
    // Nothing else owns tf_net yet, so it must be freed here.
    delete tf_net;
    return 0;
  }
  // Ownership of tf_net passes to network_.
  network_ = tf_net;
  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
  return training_iteration_;
}
#endif
187 
188 // Resets all the iteration counters for fine tuning or training a head,
189 // where we want the error reporting to reset.
191  sample_iteration_ = 0;
195  best_error_rate_ = 100.0;
196  best_iteration_ = 0;
197  worst_error_rate_ = 0.0;
198  worst_iteration_ = 0;
201  perfect_delay_ = 0;
203  for (int i = 0; i < ET_COUNT; ++i) {
204  best_error_rates_[i] = 100.0;
205  worst_error_rates_[i] = 0.0;
207  error_rates_[i] = 100.0;
208  }
210 }
211 
212 // If the training sample is usable, grid searches for the optimal
213 // dict_ratio/cert_offset, and returns the results in a string of space-
214 // separated triplets of ratio,offset=worderr.
216  const ImageData* trainingdata, int iteration, double min_dict_ratio,
217  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
218  double cert_offset_step, double max_cert_offset, STRING* results) {
219  sample_iteration_ = iteration;
220  NetworkIO fwd_outputs, targets;
221  Trainability result =
222  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
223  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == nullptr)
224  return result;
225 
226  // Encode/decode the truth to get the normalization.
227  GenericVector<int> truth_labels, ocr_labels, xcoords;
228  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
229  // NO-dict error.
230  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), nullptr);
231  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
232  nullptr);
233  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
234  STRING truth_text = DecodeLabels(truth_labels);
235  STRING ocr_text = DecodeLabels(ocr_labels);
236  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
237  results->add_str_double("0,0=", baseline_error);
238 
240  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
241  for (double c = min_cert_offset; c < max_cert_offset;
242  c += cert_offset_step) {
243  search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, nullptr);
244  search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
245  truth_text = DecodeLabels(truth_labels);
246  ocr_text = DecodeLabels(ocr_labels);
247  // This is destructive on both strings.
248  double word_error = ComputeWordError(&truth_text, &ocr_text);
249  if ((r == min_dict_ratio && c == min_cert_offset) ||
250  !std::isfinite(word_error)) {
251  STRING t = DecodeLabels(truth_labels);
252  STRING o = DecodeLabels(ocr_labels);
253  tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
254  t.c_str(), o.c_str(), word_error, truth_labels[0]);
255  }
256  results->add_str_double(" ", r);
257  results->add_str_double(",", c);
258  results->add_str_double("=", word_error);
259  }
260  }
261  return result;
262 }
263 
264 // Provides output on the distribution of weight values.
267 }
268 
269 // Loads a set of lstmf files that were created using the lstm.train config to
270 // tesseract into memory ready for training. Returns false if nothing was
271 // loaded.
273  CachingStrategy cache_strategy,
274  bool randomly_rotate) {
275  randomly_rotate_ = randomly_rotate;
277  return training_data_.LoadDocuments(filenames, cache_strategy,
279 }
280 
281 // Keeps track of best and locally worst char error_rate and launches tests
282 // using tester, when a new min or max is reached.
283 // Writes checkpoints at appropriate times and builds and returns a log message
284 // to indicate progress. Returns false if nothing interesting happened.
286  PrepareLogMsg(log_msg);
287  double error_rate = CharError();
288  int iteration = learning_iteration();
289  if (iteration >= stall_iteration_ &&
290  error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
292  // It hasn't got any better in a long while, and is a margin worse than the
293  // best, so go back to the best model and try a different learning rate.
294  StartSubtrainer(log_msg);
295  }
296  SubTrainerResult sub_trainer_result = STR_NONE;
297  if (sub_trainer_ != nullptr) {
298  sub_trainer_result = UpdateSubtrainer(log_msg);
299  if (sub_trainer_result == STR_REPLACED) {
300  // Reset the inputs, as we have overwritten *this.
301  error_rate = CharError();
302  iteration = learning_iteration();
303  PrepareLogMsg(log_msg);
304  }
305  }
306  bool result = true; // Something interesting happened.
307  GenericVector<char> rec_model_data;
308  if (error_rate < best_error_rate_) {
309  SaveRecognitionDump(&rec_model_data);
310  log_msg->add_str_double(" New best char error = ", error_rate);
311  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
312  // If sub_trainer_ is not nullptr, either *this beat it to a new best, or it
313  // just overwrote *this. In either case, we have finished with it.
314  delete sub_trainer_;
315  sub_trainer_ = nullptr;
318  log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
319  }
322  STRING best_model_name = DumpFilename();
323  if (!SaveDataToFile(best_trainer_, best_model_name.c_str())) {
324  *log_msg += " failed to write best model:";
325  } else {
326  *log_msg += " wrote best model:";
328  }
329  *log_msg += best_model_name;
330  }
331  } else if (error_rate > worst_error_rate_) {
332  SaveRecognitionDump(&rec_model_data);
333  log_msg->add_str_double(" New worst char error = ", error_rate);
334  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
337  // Error rate has ballooned. Go back to the best model.
338  *log_msg += "\nDivergence! ";
339  // Copy best_trainer_ before reading it, as it will get overwritten.
340  GenericVector<char> revert_data(best_trainer_);
341  if (ReadTrainingDump(revert_data, this)) {
342  LogIterations("Reverted to", log_msg);
343  ReduceLearningRates(this, log_msg);
344  } else {
345  LogIterations("Failed to Revert at", log_msg);
346  }
347  // If it fails again, we will wait twice as long before reverting again.
348  stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
349  // Re-save the best trainer with the new learning rates and stall
350  // iteration.
352  }
353  } else {
354  // Something interesting happened only if the sub_trainer_ was trained.
355  result = sub_trainer_result != STR_NONE;
356  }
357  if (checkpoint_name_.length() > 0) {
358  // Write a current checkpoint.
359  GenericVector<char> checkpoint;
360  if (!SaveTrainingDump(FULL, this, &checkpoint) ||
361  !SaveDataToFile(checkpoint, checkpoint_name_.c_str())) {
362  *log_msg += " failed to write checkpoint.";
363  } else {
364  *log_msg += " wrote checkpoint.";
365  }
366  }
367  *log_msg += "\n";
368  return result;
369 }
370 
371 // Builds a string containing a progress message with current error rates.
372 void LSTMTrainer::PrepareLogMsg(STRING* log_msg) const {
373  LogIterations("At", log_msg);
374  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
375  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
376  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
377  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
378  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
379  *log_msg += "%, ";
380 }
381 
382 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
383 // sample_iteration() to the log_msg.
384 void LSTMTrainer::LogIterations(const char* intro_str, STRING* log_msg) const {
385  *log_msg += intro_str;
386  log_msg->add_str_int(" iteration ", learning_iteration());
387  log_msg->add_str_int("/", training_iteration());
388  log_msg->add_str_int("/", sample_iteration());
389 }
390 
391 // Returns true and increments the training_stage_ if the error rate has just
392 // passed through the given threshold for the first time.
393 bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
394  if (best_error_rate_ < error_threshold &&
396  ++training_stage_;
397  return true;
398  }
399  return false;
400 }
401 
402 // Writes to the given file. Returns false in case of error.
404  const TessdataManager* mgr, TFile* fp) const {
405  if (!LSTMRecognizer::Serialize(mgr, fp)) return false;
406  if (!fp->Serialize(&learning_iteration_)) return false;
407  if (!fp->Serialize(&prev_sample_iteration_)) return false;
408  if (!fp->Serialize(&perfect_delay_)) return false;
409  if (!fp->Serialize(&last_perfect_training_iteration_)) return false;
410  for (const auto & error_buffer : error_buffers_) {
411  if (!error_buffer.Serialize(fp)) return false;
412  }
413  if (!fp->Serialize(&error_rates_[0], countof(error_rates_))) return false;
414  if (!fp->Serialize(&training_stage_)) return false;
415  uint8_t amount = serialize_amount;
416  if (!fp->Serialize(&amount)) return false;
417  if (serialize_amount == LIGHT) return true; // We are done.
418  if (!fp->Serialize(&best_error_rate_)) return false;
419  if (!fp->Serialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
420  if (!fp->Serialize(&best_iteration_)) return false;
421  if (!fp->Serialize(&worst_error_rate_)) return false;
422  if (!fp->Serialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
423  if (!fp->Serialize(&worst_iteration_)) return false;
424  if (!fp->Serialize(&stall_iteration_)) return false;
425  if (!best_model_data_.Serialize(fp)) return false;
426  if (!worst_model_data_.Serialize(fp)) return false;
427  if (serialize_amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp))
428  return false;
429  GenericVector<char> sub_data;
430  if (sub_trainer_ != nullptr && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
431  return false;
432  if (!sub_data.Serialize(fp)) return false;
433  if (!best_error_history_.Serialize(fp)) return false;
434  if (!best_error_iterations_.Serialize(fp)) return false;
435  return fp->Serialize(&improvement_steps_);
436 }
437 
438 // Reads from the given file. Returns false in case of error.
439 // NOTE: It is assumed that the trainer is never read cross-endian.
441  if (!LSTMRecognizer::DeSerialize(mgr, fp)) return false;
442  if (!fp->DeSerialize(&learning_iteration_)) {
443  // Special case. If we successfully decoded the recognizer, but fail here
444  // then it means we were just given a recognizer, so issue a warning and
445  // allow it.
446  tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
449  return true;
450  }
451  if (!fp->DeSerialize(&prev_sample_iteration_)) return false;
452  if (!fp->DeSerialize(&perfect_delay_)) return false;
453  if (!fp->DeSerialize(&last_perfect_training_iteration_)) return false;
454  for (auto & error_buffer : error_buffers_) {
455  if (!error_buffer.DeSerialize(fp)) return false;
456  }
457  if (!fp->DeSerialize(&error_rates_[0], countof(error_rates_))) return false;
458  if (!fp->DeSerialize(&training_stage_)) return false;
459  uint8_t amount;
460  if (!fp->DeSerialize(&amount)) return false;
461  if (amount == LIGHT) return true; // Don't read the rest.
462  if (!fp->DeSerialize(&best_error_rate_)) return false;
463  if (!fp->DeSerialize(&best_error_rates_[0], countof(best_error_rates_))) return false;
464  if (!fp->DeSerialize(&best_iteration_)) return false;
465  if (!fp->DeSerialize(&worst_error_rate_)) return false;
466  if (!fp->DeSerialize(&worst_error_rates_[0], countof(worst_error_rates_))) return false;
467  if (!fp->DeSerialize(&worst_iteration_)) return false;
468  if (!fp->DeSerialize(&stall_iteration_)) return false;
469  if (!best_model_data_.DeSerialize(fp)) return false;
470  if (!worst_model_data_.DeSerialize(fp)) return false;
471  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
472  GenericVector<char> sub_data;
473  if (!sub_data.DeSerialize(fp)) return false;
474  delete sub_trainer_;
475  if (sub_data.empty()) {
476  sub_trainer_ = nullptr;
477  } else {
478  sub_trainer_ = new LSTMTrainer();
479  if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
480  }
481  if (!best_error_history_.DeSerialize(fp)) return false;
482  if (!best_error_iterations_.DeSerialize(fp)) return false;
483  return fp->DeSerialize(&improvement_steps_);
484 }
485 
486 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
487 // learning rates (by scaling reduction, or layer specific, according to
488 // NF_LAYER_SPECIFIC_LR).
490  delete sub_trainer_;
491  sub_trainer_ = new LSTMTrainer();
493  *log_msg += " Failed to revert to previous best for trial!";
494  delete sub_trainer_;
495  sub_trainer_ = nullptr;
496  } else {
497  log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
499  // Reduce learning rate so it doesn't diverge this time.
500  sub_trainer_->ReduceLearningRates(this, log_msg);
501  // If it fails again, we will wait twice as long before reverting again.
502  int stall_offset =
504  stall_iteration_ = learning_iteration() + 2 * stall_offset;
506  // Re-save the best trainer with the new learning rates and stall iteration.
508  }
509 }
510 
511 // While the sub_trainer_ is behind the current training iteration and its
512 // training error is at least kSubTrainerMarginFraction better than the
513 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
514 // it did anything. If it catches up, and has a better error rate than the
515 // current best, as well as a margin over the current error rate, then the
516 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
517 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
518 // receive any training iterations.
520  double training_error = CharError();
521  double sub_error = sub_trainer_->CharError();
522  double sub_margin = (training_error - sub_error) / sub_error;
523  if (sub_margin >= kSubTrainerMarginFraction) {
524  log_msg->add_str_double(" sub_trainer=", sub_error);
525  log_msg->add_str_double(" margin=", 100.0 * sub_margin);
526  *log_msg += "\n";
527  // Catch up to current iteration.
528  int end_iteration = training_iteration();
529  while (sub_trainer_->training_iteration() < end_iteration &&
530  sub_margin >= kSubTrainerMarginFraction) {
531  int target_iteration =
533  while (sub_trainer_->training_iteration() < target_iteration) {
534  sub_trainer_->TrainOnLine(this, false);
535  }
536  STRING batch_log = "Sub:";
537  sub_trainer_->PrepareLogMsg(&batch_log);
538  batch_log += "\n";
539  tprintf("UpdateSubtrainer:%s", batch_log.c_str());
540  *log_msg += batch_log;
541  sub_error = sub_trainer_->CharError();
542  sub_margin = (training_error - sub_error) / sub_error;
543  }
544  if (sub_error < best_error_rate_ &&
545  sub_margin >= kSubTrainerMarginFraction) {
546  // The sub_trainer_ has won the race to a new best. Switch to it.
547  GenericVector<char> updated_trainer;
548  SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
549  ReadTrainingDump(updated_trainer, this);
550  log_msg->add_str_int(" Sub trainer wins at iteration ",
552  *log_msg += "\n";
553  return STR_REPLACED;
554  }
555  return STR_UPDATED;
556  }
557  return STR_NONE;
558 }
559 
560 // Reduces network learning rates, either for everything, or for layers
561 // independently, according to NF_LAYER_SPECIFIC_LR.
563  STRING* log_msg) {
565  int num_reduced = ReduceLayerLearningRates(
566  kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
567  log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
568  } else {
570  log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
571  }
572  *log_msg += "\n";
573 }
574 
575 // Considers reducing the learning rate independently for each layer down by
576 // factor(<1), or leaving it the same, by double-training the given number of
577 // samples and minimizing the amount of changing of sign of weight updates.
578 // Even if it looks like all weights should remain the same, an adjustment
579 // will be made to guarantee a different result when reverting to an old best.
580 // Returns the number of layer learning rates that were reduced.
581 int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
582  LSTMTrainer* samples_trainer) {
583  enum WhichWay {
584  LR_DOWN, // Learning rate will go down by factor.
585  LR_SAME, // Learning rate will stay the same.
586  LR_COUNT // Size of arrays.
587  };
589  int num_layers = layers.size();
590  GenericVector<int> num_weights;
591  num_weights.init_to_size(num_layers, 0);
592  GenericVector<double> bad_sums[LR_COUNT];
593  GenericVector<double> ok_sums[LR_COUNT];
594  for (int i = 0; i < LR_COUNT; ++i) {
595  bad_sums[i].init_to_size(num_layers, 0.0);
596  ok_sums[i].init_to_size(num_layers, 0.0);
597  }
598  double momentum_factor = 1.0 / (1.0 - momentum_);
599  GenericVector<char> orig_trainer;
600  samples_trainer->SaveTrainingDump(LIGHT, this, &orig_trainer);
601  for (int i = 0; i < num_layers; ++i) {
602  Network* layer = GetLayer(layers[i]);
603  num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
604  }
605  int iteration = sample_iteration();
606  for (int s = 0; s < num_samples; ++s) {
607  // Which way will we modify the learning rate?
608  for (int ww = 0; ww < LR_COUNT; ++ww) {
609  // Transfer momentum to learning rate and adjust by the ww factor.
610  float ww_factor = momentum_factor;
611  if (ww == LR_DOWN) ww_factor *= factor;
612  // Make a copy of *this, so we can mess about without damaging anything.
613  LSTMTrainer copy_trainer;
614  samples_trainer->ReadTrainingDump(orig_trainer, &copy_trainer);
615  // Clear the updates, doing nothing else.
616  copy_trainer.network_->Update(0.0, 0.0, 0.0, 0);
617  // Adjust the learning rate in each layer.
618  for (int i = 0; i < num_layers; ++i) {
619  if (num_weights[i] == 0) continue;
620  copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
621  }
622  copy_trainer.SetIteration(iteration);
623  // Train on the sample, but keep the update in updates_ instead of
624  // applying to the weights.
625  const ImageData* trainingdata =
626  copy_trainer.TrainOnLine(samples_trainer, true);
627  if (trainingdata == nullptr) continue;
628  // We'll now use this trainer again for each layer.
629  GenericVector<char> updated_trainer;
630  samples_trainer->SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
631  for (int i = 0; i < num_layers; ++i) {
632  if (num_weights[i] == 0) continue;
633  LSTMTrainer layer_trainer;
634  samples_trainer->ReadTrainingDump(updated_trainer, &layer_trainer);
635  Network* layer = layer_trainer.GetLayer(layers[i]);
636  // Update the weights in just the layer, using Adam if enabled.
637  layer->Update(0.0, momentum_, adam_beta_,
638  layer_trainer.training_iteration_ + 1);
639  // Zero the updates matrix again.
640  layer->Update(0.0, 0.0, 0.0, 0);
641  // Train again on the same sample, again holding back the updates.
642  layer_trainer.TrainOnLine(trainingdata, true);
643  // Count the sign changes in the updates in layer vs in copy_trainer.
644  float before_bad = bad_sums[ww][i];
645  float before_ok = ok_sums[ww][i];
646  layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
647  &ok_sums[ww][i], &bad_sums[ww][i]);
648  float bad_frac =
649  bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
650  if (bad_frac > 0.0f)
651  bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
652  }
653  }
654  ++iteration;
655  }
656  int num_lowered = 0;
657  for (int i = 0; i < num_layers; ++i) {
658  if (num_weights[i] == 0) continue;
659  Network* layer = GetLayer(layers[i]);
660  float lr = GetLayerLearningRate(layers[i]);
661  double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
662  double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
663  double frac_down = bad_sums[LR_DOWN][i] / total_down;
664  double frac_same = bad_sums[LR_SAME][i] / total_same;
665  tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().c_str(),
666  lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
667  if (frac_down < frac_same * kImprovementFraction) {
668  tprintf(" REDUCED\n");
669  ScaleLayerLearningRate(layers[i], factor);
670  ++num_lowered;
671  } else {
672  tprintf(" SAME\n");
673  }
674  }
675  if (num_lowered == 0) {
676  // Just lower everything to make sure.
677  for (int i = 0; i < num_layers; ++i) {
678  if (num_weights[i] > 0) {
679  ScaleLayerLearningRate(layers[i], factor);
680  ++num_lowered;
681  }
682  }
683  }
684  return num_lowered;
685 }
686 
687 // Converts the string to integer class labels, with appropriate null_char_s
688 // in between if not in SimpleTextOutput mode. Returns false on failure.
689 /* static */
690 bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
691  const UnicharCompress* recoder, bool simple_text,
692  int null_char, GenericVector<int>* labels) {
693  if (str.c_str() == nullptr || str.length() <= 0) {
694  tprintf("Empty truth string!\n");
695  return false;
696  }
697  int err_index;
698  GenericVector<int> internal_labels;
699  labels->truncate(0);
700  if (!simple_text) labels->push_back(null_char);
701  std::string cleaned = unicharset.CleanupString(str.c_str());
702  if (unicharset.encode_string(cleaned.c_str(), true, &internal_labels, nullptr,
703  &err_index)) {
704  bool success = true;
705  for (int i = 0; i < internal_labels.size(); ++i) {
706  if (recoder != nullptr) {
707  // Re-encode labels via recoder.
708  RecodedCharID code;
709  int len = recoder->EncodeUnichar(internal_labels[i], &code);
710  if (len > 0) {
711  for (int j = 0; j < len; ++j) {
712  labels->push_back(code(j));
713  if (!simple_text) labels->push_back(null_char);
714  }
715  } else {
716  success = false;
717  err_index = 0;
718  break;
719  }
720  } else {
721  labels->push_back(internal_labels[i]);
722  if (!simple_text) labels->push_back(null_char);
723  }
724  }
725  if (success) return true;
726  }
727  tprintf("Encoding of string failed! Failure bytes:");
728  while (err_index < cleaned.size()) {
729  tprintf(" %x", cleaned[err_index++] & 0xff);
730  }
731  tprintf("\n");
732  return false;
733 }
734 
735 // Performs forward-backward on the given trainingdata.
736 // Returns a Trainability enum to indicate the suitability of the sample.
738  bool batch) {
739  NetworkIO fwd_outputs, targets;
740  Trainability trainable =
741  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
743  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
744  return trainable; // Sample was unusable.
745  }
746  bool debug = debug_interval_ > 0 &&
748  // Run backprop on the output.
749  NetworkIO bp_deltas;
750  if (network_->IsTraining() &&
751  (trainable != PERFECT ||
754  network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
756  training_iteration_ + 1);
757  }
758 #ifndef GRAPHICS_DISABLED
759  if (debug_interval_ == 1 && debug_win_ != nullptr) {
761  }
762 #endif // GRAPHICS_DISABLED
763  // Roll the memory of past means.
765  return trainable;
766 }
767 
768 // Prepares the ground truth, runs forward, and prepares the targets.
769 // Returns a Trainability enum to indicate the suitability of the sample.
771  NetworkIO* fwd_outputs,
772  NetworkIO* targets) {
773  if (trainingdata == nullptr) {
774  tprintf("Null trainingdata.\n");
775  return UNENCODABLE;
776  }
777  // Ensure repeatability of random elements even across checkpoints.
778  bool debug = debug_interval_ > 0 &&
780  GenericVector<int> truth_labels;
781  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
782  tprintf("Can't encode transcription: '%s' in language '%s'\n",
783  trainingdata->transcription().c_str(),
784  trainingdata->language().c_str());
785  return UNENCODABLE;
786  }
787  bool upside_down = false;
788  if (randomly_rotate_) {
789  // This ensures consistent training results.
790  SetRandomSeed();
791  upside_down = randomizer_.SignedRand(1.0) > 0.0;
792  if (upside_down) {
793  // Modify the truth labels to match the rotation:
794 // Apart from space and null, increment the label. This changes the
795  // script-id to the same script-id but upside-down.
796  // The labels need to be reversed in order, as the first is now the last.
797  for (int c = 0; c < truth_labels.size(); ++c) {
798  if (truth_labels[c] != UNICHAR_SPACE && truth_labels[c] != null_char_)
799  ++truth_labels[c];
800  }
801  truth_labels.reverse();
802  }
803  }
804  int w = 0;
805  while (w < truth_labels.size() &&
806  (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
807  ++w;
808  if (w == truth_labels.size()) {
809  tprintf("Blank transcription: %s\n",
810  trainingdata->transcription().c_str());
811  return UNENCODABLE;
812  }
813  float image_scale;
814  NetworkIO inputs;
815  bool invert = trainingdata->boxes().empty();
816  if (!RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
817  &image_scale, &inputs, fwd_outputs)) {
818  tprintf("Image not trainable\n");
819  return UNENCODABLE;
820  }
821  targets->Resize(*fwd_outputs, network_->NumOutputs());
822  LossType loss_type = OutputLossType();
823  if (loss_type == LT_SOFTMAX) {
824  if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
825  tprintf("Compute simple targets failed!\n");
826  return UNENCODABLE;
827  }
828  } else if (loss_type == LT_CTC) {
829  if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
830  tprintf("Compute CTC targets failed!\n");
831  return UNENCODABLE;
832  }
833  } else {
834  tprintf("Logistic outputs not implemented yet!\n");
835  return UNENCODABLE;
836  }
837  GenericVector<int> ocr_labels;
838  GenericVector<int> xcoords;
839  LabelsFromOutputs(*fwd_outputs, &ocr_labels, &xcoords);
840  // CTC does not produce correct target labels to begin with.
841  if (loss_type != LT_CTC) {
842  LabelsFromOutputs(*targets, &truth_labels, &xcoords);
843  }
844  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
845  *targets)) {
846  tprintf("Input width was %d\n", inputs.Width());
847  return UNENCODABLE;
848  }
849  STRING ocr_text = DecodeLabels(ocr_labels);
850  STRING truth_text = DecodeLabels(truth_labels);
851  targets->SubtractAllFromFloat(*fwd_outputs);
852  if (debug_interval_ != 0) {
853  if (truth_text != ocr_text) {
854  tprintf("Iteration %d: BEST OCR TEXT : %s\n",
855  training_iteration(), ocr_text.c_str());
856  }
857  }
858  double char_error = ComputeCharError(truth_labels, ocr_labels);
859  double word_error = ComputeWordError(&truth_text, &ocr_text);
860  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
861  if (debug_interval_ != 0) {
862  tprintf("File %s line %d %s:\n", trainingdata->imagefilename().c_str(),
863  trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
864  }
865  if (delta_error == 0.0) return PERFECT;
867  return TRAINABLE;
868 }
869 
870 // Writes the trainer to memory, so that the current training state can be
871 // restored. *this must always be the master trainer that retains the only
872 // copy of the training data and language model. trainer is the model that is
873 // actually serialized.
875  const LSTMTrainer* trainer,
876  GenericVector<char>* data) const {
877  TFile fp;
878  fp.OpenWrite(data);
879  return trainer->Serialize(serialize_amount, &mgr_, &fp);
880 }
881 
882 // Restores the model to *this.
884  const char* data, int size) {
885  if (size == 0) {
886  tprintf("Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
887  return false;
888  }
889  TFile fp;
890  fp.Open(data, size);
891  return DeSerialize(mgr, &fp);
892 }
893 
894 // Writes the full recognition traineddata to the given filename.
895 bool LSTMTrainer::SaveTraineddata(const STRING& filename) {
896  GenericVector<char> recognizer_data;
897  SaveRecognitionDump(&recognizer_data);
898  mgr_.OverwriteEntry(TESSDATA_LSTM, &recognizer_data[0],
899  recognizer_data.size());
900  return mgr_.SaveFile(filename, SaveDataToFile);
901 }
902 
903 // Writes the recognizer to memory, so that it can be used for testing later.
905  TFile fp;
906  fp.OpenWrite(data);
910 }
911 
912 // Returns a suitable filename for a training dump, based on the model_base_,
913 // best_error_rate_, best_iteration_ and training_iteration_.
915  STRING filename;
916  filename += model_base_.c_str();
917  filename.add_str_double("_", best_error_rate_);
918  filename.add_str_int("_", best_iteration_);
919  filename.add_str_int("_", training_iteration_);
920  filename += ".checkpoint";
921  return filename;
922 }
923 
924 // Fills the whole error buffer of the given type with the given value.
926  for (int i = 0; i < kRollingBufferSize_; ++i)
927  error_buffers_[type][i] = new_error;
928  error_rates_[type] = 100.0 * new_error;
929 }
930 
931 // Helper generates a map from each current recoder_ code (ie softmax index)
932 // to the corresponding old_recoder code, or -1 if there isn't one.
933 std::vector<int> LSTMTrainer::MapRecoder(
934  const UNICHARSET& old_chset, const UnicharCompress& old_recoder) const {
935  int num_new_codes = recoder_.code_range();
936  int num_new_unichars = GetUnicharset().size();
937  std::vector<int> code_map(num_new_codes, -1);
938  for (int c = 0; c < num_new_codes; ++c) {
939  int old_code = -1;
940  // Find all new unichar_ids that recode to something that includes c.
941  // The <= is to include the null char, which may be beyond the unicharset.
942  for (int uid = 0; uid <= num_new_unichars; ++uid) {
943  RecodedCharID codes;
944  int length = recoder_.EncodeUnichar(uid, &codes);
945  int code_index = 0;
946  while (code_index < length && codes(code_index) != c) ++code_index;
947  if (code_index == length) continue;
948  // The old unicharset must have the same unichar.
949  int old_uid =
950  uid < num_new_unichars
951  ? old_chset.unichar_to_id(GetUnicharset().id_to_unichar(uid))
952  : old_chset.size() - 1;
953  if (old_uid == INVALID_UNICHAR_ID) continue;
954  // The encoding of old_uid at the same code_index is the old code.
955  RecodedCharID old_codes;
956  if (code_index < old_recoder.EncodeUnichar(old_uid, &old_codes)) {
957  old_code = old_codes(code_index);
958  break;
959  }
960  }
961  code_map[c] = old_code;
962  }
963  return code_map;
964 }
965 
966 // Private version of InitCharSet above finishes the job after initializing
967 // the mgr_ data member.
971  // Initialize the unicharset and recoder.
972  if (!LoadCharsets(&mgr_)) {
973  ASSERT_HOST(
974  "Must provide a traineddata containing lstm_unicharset and"
975  " lstm_recoder!\n" != nullptr);
976  }
977  SetNullChar();
978 }
979 
980 // Helper computes and sets the null_char_.
983  : GetUnicharset().size();
984  RecodedCharID code;
986  null_char_ = code(0);
987 }
988 
989 // Factored sub-constructor sets up reasonable default values.
991  align_win_ = nullptr;
992  target_win_ = nullptr;
993  ctc_win_ = nullptr;
994  recon_win_ = nullptr;
996  training_stage_ = 0;
998  InitIterations();
999 }
1000 
1001 // Outputs the string and periodically displays the given network inputs
1002 // as an image in the given window, and the corresponding labels at the
1003 // corresponding x_starts.
1004 // Returns false if the truth string is empty.
1006  const ImageData& trainingdata,
1007  const NetworkIO& fwd_outputs,
1008  const GenericVector<int>& truth_labels,
1009  const NetworkIO& outputs) {
1010  const STRING& truth_text = DecodeLabels(truth_labels);
1011  if (truth_text.c_str() == nullptr || truth_text.length() <= 0) {
1012  tprintf("Empty truth string at decode time!\n");
1013  return false;
1014  }
1015  if (debug_interval_ != 0) {
1016  // Get class labels, xcoords and string.
1017  GenericVector<int> labels;
1018  GenericVector<int> xcoords;
1019  LabelsFromOutputs(outputs, &labels, &xcoords);
1020  STRING text = DecodeLabels(labels);
1021  tprintf("Iteration %d: GROUND TRUTH : %s\n",
1022  training_iteration(), truth_text.c_str());
1023  if (truth_text != text) {
1024  tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
1025  training_iteration(), text.c_str());
1026  }
1027  if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1028  tprintf("TRAINING activation path for truth string %s\n",
1029  truth_text.c_str());
1030  DebugActivationPath(outputs, labels, xcoords);
1031  DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1032  if (OutputLossType() == LT_CTC) {
1033  DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1034  DisplayTargets(outputs, "CTC Targets", &target_win_);
1035  }
1036  }
1037  }
1038  return true;
1039 }
1040 
1041 // Displays the network targets as a line graph.
1043  const char* window_name, ScrollView** window) {
1044 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1045  int width = targets.Width();
1046  int num_features = targets.NumFeatures();
1047  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1048  window);
1049  for (int c = 0; c < num_features; ++c) {
1050  int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1051  (*window)->Pen(static_cast<ScrollView::Color>(color));
1052  int start_t = -1;
1053  for (int t = 0; t < width; ++t) {
1054  double target = targets.f(t)[c];
1055  target *= kTargetYScale;
1056  if (target >= 1) {
1057  if (start_t < 0) {
1058  (*window)->SetCursor(t - 1, 0);
1059  start_t = t;
1060  }
1061  (*window)->DrawTo(t, target);
1062  } else if (start_t >= 0) {
1063  (*window)->DrawTo(t, 0);
1064  (*window)->DrawTo(start_t - 1, 0);
1065  start_t = -1;
1066  }
1067  }
1068  if (start_t >= 0) {
1069  (*window)->DrawTo(width, 0);
1070  (*window)->DrawTo(start_t - 1, 0);
1071  }
1072  }
1073  (*window)->Update();
1074 #endif // GRAPHICS_DISABLED
1075 }
1076 
1077 // Builds a no-compromises target where the first positions should be the
1078 // truth labels and the rest is padded with the null_char_.
1080  const GenericVector<int>& truth_labels,
1081  NetworkIO* targets) {
1082  if (truth_labels.size() > targets->Width()) {
1083  tprintf("Error: transcription %s too long to fit into target of width %d\n",
1084  DecodeLabels(truth_labels).c_str(), targets->Width());
1085  return false;
1086  }
1087  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
1088  targets->SetActivations(i, truth_labels[i], 1.0);
1089  }
1090  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
1091  targets->SetActivations(i, null_char_, 1.0);
1092  }
1093  return true;
1094 }
1095 
1096 // Builds a target using standard CTC. truth_labels should be pre-padded with
1097 // nulls wherever desired. They don't have to be between all labels.
1098 // outputs is input-output, as it gets clipped to minimum probability.
1100  NetworkIO* outputs, NetworkIO* targets) {
1101  // Bottom-clip outputs to a minimum probability.
1102  CTC::NormalizeProbs(outputs);
1103  return CTC::ComputeCTCTargets(truth_labels, null_char_,
1104  outputs->float_array(), targets);
1105 }
1106 
1107 // Computes network errors, and stores the results in the rolling buffers,
1108 // along with the supplied text_error.
1109 // Returns the delta error of the current sample (not running average.)
1111  double char_error, double word_error) {
1113  // Delta error is the fraction of timesteps with >0.5 error in the top choice
1114  // score. If zero, then the top choice characters are guaranteed correct,
1115  // even when there is residue in the RMS error.
1116  double delta_error = ComputeWinnerError(deltas);
1117  UpdateErrorBuffer(delta_error, ET_DELTA);
1118  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1119  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1120  // Skip ratio measures the difference between sample_iteration_ and
1121  // training_iteration_, which reflects the number of unusable samples,
1122  // usually due to unencodable truth text, or the text not fitting in the
1123  // space for the output.
1124  double skip_count = sample_iteration_ - prev_sample_iteration_;
1125  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1126  return delta_error;
1127 }
1128 
1129 // Computes the network activation RMS error rate.
1131  double total_error = 0.0;
1132  int width = deltas.Width();
1133  int num_classes = deltas.NumFeatures();
1134  for (int t = 0; t < width; ++t) {
1135  const float* class_errs = deltas.f(t);
1136  for (int c = 0; c < num_classes; ++c) {
1137  double error = class_errs[c];
1138  total_error += error * error;
1139  }
1140  }
1141  return sqrt(total_error / (width * num_classes));
1142 }
1143 
1144 // Computes network activation winner error rate. (Number of values that are
1145 // in error by >= 0.5 divided by number of time-steps.) More closely related
1146 // to final character error than RMS, but still directly calculable from
1147 // just the deltas. Because of the binary nature of the targets, zero winner
1148 // error is a sufficient but not necessary condition for zero char error.
1150  int num_errors = 0;
1151  int width = deltas.Width();
1152  int num_classes = deltas.NumFeatures();
1153  for (int t = 0; t < width; ++t) {
1154  const float* class_errs = deltas.f(t);
1155  for (int c = 0; c < num_classes; ++c) {
1156  float abs_delta = fabs(class_errs[c]);
1157  // TODO(rays) Filtering cases where the delta is very large to cut out
1158  // GT errors doesn't work. Find a better way or get better truth.
1159  if (0.5 <= abs_delta)
1160  ++num_errors;
1161  }
1162  }
1163  return static_cast<double>(num_errors) / width;
1164 }
1165 
1166 // Computes a very simple bag of chars char error rate.
1168  const GenericVector<int>& ocr_str) {
1169  GenericVector<int> label_counts;
1170  label_counts.init_to_size(NumOutputs(), 0);
1171  int truth_size = 0;
1172  for (int i = 0; i < truth_str.size(); ++i) {
1173  if (truth_str[i] != null_char_) {
1174  ++label_counts[truth_str[i]];
1175  ++truth_size;
1176  }
1177  }
1178  for (int i = 0; i < ocr_str.size(); ++i) {
1179  if (ocr_str[i] != null_char_) {
1180  --label_counts[ocr_str[i]];
1181  }
1182  }
1183  int char_errors = 0;
1184  for (int i = 0; i < label_counts.size(); ++i) {
1185  char_errors += abs(label_counts[i]);
1186  }
1187  if (truth_size == 0) {
1188  return (char_errors == 0) ? 0.0 : 1.0;
1189  }
1190  return static_cast<double>(char_errors) / truth_size;
1191 }
1192 
1193 // Computes word recall error rate using a very simple bag of words algorithm.
1194 // NOTE that this is destructive on both input strings.
1195 double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) {
1196  using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1197  GenericVector<STRING> truth_words, ocr_words;
1198  truth_str->split(' ', &truth_words);
1199  if (truth_words.empty()) return 0.0;
1200  ocr_str->split(' ', &ocr_words);
1201  StrMap word_counts;
1202  for (int i = 0; i < truth_words.size(); ++i) {
1203  std::string truth_word(truth_words[i].c_str());
1204  auto it = word_counts.find(truth_word);
1205  if (it == word_counts.end())
1206  word_counts.insert(std::make_pair(truth_word, 1));
1207  else
1208  ++it->second;
1209  }
1210  for (int i = 0; i < ocr_words.size(); ++i) {
1211  std::string ocr_word(ocr_words[i].c_str());
1212  auto it = word_counts.find(ocr_word);
1213  if (it == word_counts.end())
1214  word_counts.insert(std::make_pair(ocr_word, -1));
1215  else
1216  --it->second;
1217  }
1218  int word_recall_errs = 0;
1219  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1220  ++it) {
1221  if (it->second > 0) word_recall_errs += it->second;
1222  }
1223  return static_cast<double>(word_recall_errs) / truth_words.size();
1224 }
1225 
1226 // Updates the error buffer and corresponding mean of the given type with
1227 // the new_error.
1230  error_buffers_[type][index] = new_error;
1231  // Compute the mean error.
1232  int mean_count = std::min(training_iteration_ + 1, error_buffers_[type].size());
1233  double buffer_sum = 0.0;
1234  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
1235  double mean = buffer_sum / mean_count;
1236  // Trim precision to 1/1000 of 1%.
1237  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1238 }
1239 
1240 // Rolls error buffers and reports the current means.
1243  if (NewSingleError(ET_DELTA) > 0.0)
1245  else
1248  if (debug_interval_ != 0) {
1249  tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1253  }
1254 }
1255 
1256 // Given that error_rate is either a new min or max, updates the best/worst
1257 // error rates, and record of progress.
1258 // Tester is an externally supplied callback function that tests on some
1259 // data set with a given model and records the error rates in a graph.
1260 STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
1261  const GenericVector<char>& model_data,
1262  TestCallback tester) {
1263  if (error_rate > best_error_rate_
1264  && iteration < best_iteration_ + kErrorGraphInterval) {
1265  // Too soon to record a new point.
1266  if (tester != nullptr && !worst_model_data_.empty()) {
1269  return tester(worst_iteration_, nullptr, mgr_, CurrentTrainingStage());
1270  } else {
1271  return "";
1272  }
1273  }
1274  STRING result;
1275  // NOTE: there are 2 asymmetries here:
1276  // 1. We are computing the global minimum, but the local maximum in between.
1277  // 2. If the tester returns an empty string, indicating that it is busy,
1278  // call it repeatedly on new local maxima to test the previous min, but
1279  // not the other way around, as there is little point testing the maxima
1280  // between very frequent minima.
1281  if (error_rate < best_error_rate_) {
1282  // This is a new (global) minimum.
1283  if (tester != nullptr && !worst_model_data_.empty()) {
1286  result = tester(worst_iteration_, worst_error_rates_, mgr_,
1289  best_model_data_ = model_data;
1290  }
1291  best_error_rate_ = error_rate;
1292  memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1293  best_iteration_ = iteration;
1294  best_error_history_.push_back(error_rate);
1295  best_error_iterations_.push_back(iteration);
1296  // Compute 2% decay time.
1297  double two_percent_more = error_rate + 2.0;
1298  int i;
1299  for (i = best_error_history_.size() - 1;
1300  i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1301  }
1302  int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1303  improvement_steps_ = iteration - old_iteration;
1304  tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1305  improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1306  old_iteration);
1307  } else if (error_rate > best_error_rate_) {
1308  // This is a new (local) maximum.
1309  if (tester != nullptr) {
1310  if (!best_model_data_.empty()) {
1313  result = tester(best_iteration_, best_error_rates_, mgr_,
1315  } else if (!worst_model_data_.empty()) {
1316  // Allow for multiple data points with "worst" error rate.
1319  result = tester(worst_iteration_, worst_error_rates_, mgr_,
1321  }
1322  if (result.length() > 0)
1324  worst_model_data_ = model_data;
1325  }
1326  }
1327  worst_error_rate_ = error_rate;
1328  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1329  worst_iteration_ = iteration;
1330  return result;
1331 }
1332 
1333 } // namespace tesseract.
UNICHARSET::load_from_file
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:378
tesseract::ET_COUNT
Definition: lstmtrainer.h:43
tesseract::LSTMTrainer::InitNetwork
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
Definition: lstmtrainer.cpp:146
string
std::string string
Definition: equationdetect_test.cc:21
tesseract::LSTMTrainer::DebugLSTMTraining
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
Definition: lstmtrainer.cpp:1005
tesseract::TS_ENABLED
Definition: network.h:95
tesseract::RecodeBeamSearch::kMinCertainty
static constexpr float kMinCertainty
Definition: recodebeam.h:252
tesseract::StaticShape
Definition: static_shape.h:38
tesseract::kSubTrainerMarginFraction
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:51
tesseract::LSTMRecognizer::learning_rate_
float learning_rate_
Definition: lstmrecognizer.h:283
tesseract::LSTMTrainer::LoadAllTrainingData
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
Definition: lstmtrainer.cpp:272
ScrollView
Definition: scrollview.h:97
tesseract::STR_NONE
Definition: lstmtrainer.h:64
tesseract::LSTMTrainer::randomly_rotate_
bool randomly_rotate_
Definition: lstmtrainer.h:399
tesseract::STR_REPLACED
Definition: lstmtrainer.h:66
tesseract::kMinDivergenceRate
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:46
tesseract::LSTMRecognizer::dict_
Dict * dict_
Definition: lstmrecognizer.h:292
tesseract::LT_CTC
Definition: static_shape.h:31
tesseract::Network::Backward
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
tesseract::LSTMRecognizer::DebugActivationPath
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
Definition: lstmrecognizer.cpp:392
tesseract::LSTMTrainer::learning_iteration_
int learning_iteration_
Definition: lstmtrainer.h:443
tesseract::LSTMTrainer::PrepareForBackward
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:770
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition: strngs.cpp:370
tesseract::LSTMTrainer::UpdateSubtrainer
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:519
tesseract::LSTMTrainer::SaveTrainingDump
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
Definition: lstmtrainer.cpp:874
SVET_CLICK
Definition: scrollview.h:47
boxread.h
UNICHARSET::encode_string
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
Definition: unicharset.cpp:258
tesseract::TessdataManager::SetVersionString
void SetVersionString(const std::string &v_str)
Definition: tessdatamanager.cpp:239
tesseract::Network::SetEnableTraining
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
tesseract::LSTMTrainer::SaveRecognitionDump
void SaveRecognitionDump(GenericVector< char > *data) const
Definition: lstmtrainer.cpp:904
tesseract::SerializeAmount
SerializeAmount
Definition: lstmtrainer.h:56
tesseract::RecodeBeamSearch
Definition: recodebeam.h:180
tesseract::TessdataManager
Definition: tessdatamanager.h:126
lstmtrainer.h
tesseract::LSTMTrainer::mgr_
TessdataManager mgr_
Definition: lstmtrainer.h:462
tesseract::LSTMRecognizer::ScaleLayerLearningRate
void ScaleLayerLearningRate(const STRING &id, double factor)
Definition: lstmrecognizer.h:116
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::UnicharCompress::DeSerialize
bool DeSerialize(TFile *fp)
Definition: unicharcompress.cpp:305
tesseract::LSTMTrainer::best_error_rates_
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:411
tesseract::LSTMTrainer::num_training_stages_
int num_training_stages_
Definition: lstmtrainer.h:404
tesseract::LSTMTrainer::MapRecoder
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
Definition: lstmtrainer.cpp:933
tesseract::LoadDataFromFile
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
Definition: genericvector.h:341
tesseract::kTargetXScale
const int kTargetXScale
Definition: lstmtrainer.cpp:71
tesseract::LSTMTrainer::prev_sample_iteration_
int prev_sample_iteration_
Definition: lstmtrainer.h:445
tesseract::LSTMRecognizer::training_iteration
int training_iteration() const
Definition: lstmrecognizer.h:60
tesseract::LSTMTrainer::worst_error_rate_
double worst_error_rate_
Definition: lstmtrainer.h:415
tesseract::TessdataManager::VersionString
std::string VersionString() const
Definition: tessdatamanager.cpp:233
tesseract::LSTMTrainer::checkpoint_name_
STRING checkpoint_name_
Definition: lstmtrainer.h:397
tesseract::DocumentCache::LoadDocuments
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:566
tesseract::LSTMTrainer::best_error_rate_
double best_error_rate_
Definition: lstmtrainer.h:409
tesseract::countof
constexpr size_t countof(T const (&)[N]) noexcept
Definition: serialis.h:41
tesseract::LSTMTrainer::DumpFilename
STRING DumpFilename() const
Definition: lstmtrainer.cpp:914
tesseract::LSTMRecognizer::randomizer_
TRand randomizer_
Definition: lstmrecognizer.h:289
tesseract::LSTMTrainer::debug_interval_
int debug_interval_
Definition: lstmtrainer.h:391
tesseract::LSTMRecognizer::learning_rate
double learning_rate() const
Definition: lstmrecognizer.h:62
tesseract::LSTMTrainer::perfect_delay_
int perfect_delay_
Definition: lstmtrainer.h:451
tesseract::kNumAdjustmentIterations
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:55
tesseract::Network::RemapOutputs
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
Definition: network.h:186
STRING
Definition: strngs.h:45
tesseract::NetworkIO::Width
int Width() const
Definition: networkio.h:107
recodebeam.h
tesseract::UnicharCompress::code_range
int code_range() const
Definition: unicharcompress.h:161
tesseract::LSTMTrainer::error_buffers_
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:458
GenericVector::Serialize
bool Serialize(FILE *fp) const
Definition: genericvector.h:929
tesseract::ImageData::imagefilename
const STRING & imagefilename() const
Definition: imagedata.h:125
tesseract::SaveDataToFile
bool SaveDataToFile(const GenericVector< char > &data, const char *filename)
Definition: genericvector.h:362
tesseract::LSTMRecognizer::sample_iteration_
int32_t sample_iteration_
Definition: lstmrecognizer.h:278
tesseract::NetworkIO::float_array
const GENERIC_2D_ARRAY< float > & float_array() const
Definition: networkio.h:139
tesseract::LSTMTrainer::EncodeString
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:232
IntCastRounded
int IntCastRounded(double x)
Definition: helpers.h:173
tesseract::RecodeBeamSearch::Decode
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:78
tesseract::LSTMRecognizer::network_str_
STRING network_str_
Definition: lstmrecognizer.h:271
tesseract::LSTMRecognizer::EnumerateLayers
GenericVector< STRING > EnumerateLayers() const
Definition: lstmrecognizer.h:79
tesseract::LSTMRecognizer::LabelsFromOutputs
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:462
tesseract::kMinStartedErrorRate
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:61
tesseract::LSTMRecognizer::DisplayForward
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
Definition: lstmrecognizer.cpp:349
tesseract::LSTMTrainer::DebugNetwork
void DebugNetwork()
Definition: lstmtrainer.cpp:265
tesseract::LSTMRecognizer::GetUnicharset
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:132
tesseract::ImageData
Definition: imagedata.h:104
tesseract::LSTMTrainer::model_base_
STRING model_base_
Definition: lstmtrainer.h:395
tesseract::Network::DebugWeights
virtual void DebugWeights()=0
tesseract::LSTMTrainer::TrainOnLine
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:245
tesseract::LSTMTrainer::best_error_history_
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:436
tesseract::TFile::Open
bool Open(const STRING &filename, FileReader reader)
Definition: serialis.cpp:210
tesseract::Network::TestFlag
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
tesseract::ET_WORD_RECERR
Definition: lstmtrainer.h:40
tesseract::Network::IsTraining
bool IsTraining() const
Definition: network.h:115
ratngs.h
tesseract::PERFECT
Definition: lstmtrainer.h:49
tesseract::LSTMRecognizer::GetLayerLearningRate
float GetLayerLearningRate(const STRING &id) const
Definition: lstmrecognizer.h:94
tesseract::LSTMTrainer::best_trainer_
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:426
GenericVector::reverse
void reverse()
Definition: genericvector.h:215
tesseract::LSTMTrainer::FillErrorBuffer
void FillErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:925
tesseract::LSTMTrainer::StartSubtrainer
void StartSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:489
tesseract::ErrorTypes
ErrorTypes
Definition: lstmtrainer.h:37
ctc.h
tesseract::NetworkIO::SetActivations
void SetActivations(int t, int label, float ok_score)
Definition: networkio.cpp:537
tesseract::TessdataManager::OverwriteEntry
void OverwriteEntry(TessdataType type, const char *data, int size)
Definition: tessdatamanager.cpp:145
tesseract::LSTMRecognizer::LoadCharsets
bool LoadCharsets(const TessdataManager *mgr)
Definition: lstmrecognizer.cpp:133
tesseract::Network::OutputShape
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
tesseract::FULL
Definition: lstmtrainer.h:59
tesseract::ImageData::boxes
const GenericVector< TBOX > & boxes() const
Definition: imagedata.h:149
tesseract::LSTMRecognizer::SetIteration
void SetIteration(int iteration)
Definition: lstmrecognizer.h:142
tesseract::LSTMTrainer::error_rate_of_last_saved_best_
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:431
tesseract::LSTMTrainer::Serialize
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
Definition: lstmtrainer.cpp:403
tesseract::LSTMTrainer::NewSingleError
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:140
tesseract::LSTMRecognizer::DeSerialize
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmrecognizer.cpp:108
UNICHAR_BROKEN
Definition: unicharset.h:36
tesseract::LSTMTrainer::RollErrorBuffers
void RollErrorBuffers()
Definition: lstmtrainer.cpp:1241
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::LSTMRecognizer::RecognizeLine
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
Definition: lstmrecognizer.cpp:187
tesseract::LSTMRecognizer::scratch_space_
NetworkScratch scratch_space_
Definition: lstmrecognizer.h:290
tesseract::LSTMTrainer::align_win_
ScrollView * align_win_
Definition: lstmtrainer.h:383
tesseract::LSTMTrainer::ReadLocalTrainingDump
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
Definition: lstmtrainer.cpp:883
tesseract::LossType
LossType
Definition: static_shape.h:29
tesseract::LSTMRecognizer::training_flags_
int32_t training_flags_
Definition: lstmrecognizer.h:274
tesseract::ImageData::transcription
const STRING & transcription() const
Definition: imagedata.h:146
tesseract::LSTMRecognizer::null_char
int null_char() const
Definition: lstmrecognizer.h:145
tesseract::LSTMRecognizer::SetRandomSeed
void SetRandomSeed()
Definition: lstmrecognizer.h:217
tesseract::LSTMTrainer::CharError
double CharError() const
Definition: lstmtrainer.h:125
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
tesseract::LSTMRecognizer::momentum_
float momentum_
Definition: lstmrecognizer.h:284
tesseract::TFile::DeSerialize
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:117
tesseract::LSTMRecognizer::debug_win_
ScrollView * debug_win_
Definition: lstmrecognizer.h:298
GenericVector::DeSerialize
bool DeSerialize(bool swap, FILE *fp)
Definition: genericvector.h:954
tesseract::TFile::Serialize
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:161
tesseract::LSTMTrainer::best_model_data_
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:423
tesseract::LSTMTrainer::UpdateErrorBuffer
void UpdateErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:1228
tesseract::LSTMTrainer::SetNullChar
void SetNullChar()
Definition: lstmtrainer.cpp:981
tesseract::LSTMTrainer::learning_iteration
int learning_iteration() const
Definition: lstmtrainer.h:135
tesseract::LSTMTrainer::ComputeRMSError
double ComputeRMSError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1130
tesseract::TessdataManager::GetComponent
bool GetComponent(TessdataType type, TFile *fp)
Definition: tessdatamanager.cpp:216
UNICHARSET::unichar_to_id
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
Definition: unicharset.cpp:209
tesseract::LSTMRecognizer::OutputLossType
LossType OutputLossType() const
Definition: lstmrecognizer.h:63
UNICHAR_SPACE
Definition: unicharset.h:34
tesseract::TESSDATA_LSTM_RECODER
Definition: tessdatamanager.h:79
tesseract::LSTMRecognizer::ScaleLearningRate
void ScaleLearningRate(double factor)
Definition: lstmrecognizer.h:105
networkbuilder.h
tesseract::NetworkIO::f
float * f(int t)
Definition: networkio.h:115
tesseract::Network::InputShape
virtual StaticShape InputShape() const
Definition: network.h:127
tesseract::TFile
Definition: serialis.h:75
tesseract::LSTMTrainer::ComputeErrorRates
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
Definition: lstmtrainer.cpp:1110
tesseract::STR_UPDATED
Definition: lstmtrainer.h:65
tesseract::kNumPagesPerBatch
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:59
GenericVector::empty
bool empty() const
Definition: genericvector.h:86
UNICHARSET
Definition: unicharset.h:145
tesseract::LSTMTrainer::improvement_steps_
int32_t improvement_steps_
Definition: lstmtrainer.h:439
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::LSTMRecognizer::GetLayer
Network * GetLayer(const STRING &id) const
Definition: lstmrecognizer.h:87
tesseract::LSTMTrainer::ctc_win_
ScrollView * ctc_win_
Definition: lstmtrainer.h:387
tesseract::kStageTransitionThreshold
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:63
tesseract::LSTMRecognizer::Serialize
bool Serialize(const TessdataManager *mgr, TFile *fp) const
Definition: lstmrecognizer.cpp:89
tesseract::LSTMRecognizer::recoder_
UnicharCompress recoder_
Definition: lstmrecognizer.h:268
UNICHARSET::CleanupString
static std::string CleanupString(const char *utf8_str)
Definition: unicharset.h:246
tesseract::Trainability
Trainability
Definition: lstmtrainer.h:47
tesseract::LSTMRecognizer::training_iteration_
int32_t training_iteration_
Definition: lstmrecognizer.h:276
tesseract::kTargetYScale
const int kTargetYScale
Definition: lstmtrainer.cpp:72
tesseract::TS_RE_ENABLE
Definition: network.h:99
tesseract::LSTMTrainer::worst_iteration_
int worst_iteration_
Definition: lstmtrainer.h:419
tesseract
Definition: baseapi.h:65
tesseract::LSTMTrainer::TryLoadingCheckpoint
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
Definition: lstmtrainer.cpp:103
tesseract::LSTMTrainer::ComputeWordError
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
Definition: lstmtrainer.cpp:1195
tesseract::ImageData::page_number
int page_number() const
Definition: imagedata.h:131
tesseract::LSTMTrainer::target_win_
ScrollView * target_win_
Definition: lstmtrainer.h:385
tesseract::LSTMTrainer::stall_iteration_
int stall_iteration_
Definition: lstmtrainer.h:421
tesseract::LSTMTrainer::ReadTrainingDump
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:277
tesseract::LSTMTrainer::training_stage_
int training_stage_
Definition: lstmtrainer.h:433
tesseract::SubTrainerResult
SubTrainerResult
Definition: lstmtrainer.h:63
tesseract::LSTMTrainer::GridSearchDictParams
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
Definition: lstmtrainer.cpp:215
tesseract::kLearningRateDecay
const double kLearningRateDecay
Definition: lstmtrainer.cpp:53
tesseract::LSTMRecognizer::SimpleTextOutput
bool SimpleTextOutput() const
Definition: lstmrecognizer.h:69
tesseract::LSTMTrainer::DisplayTargets
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
Definition: lstmtrainer.cpp:1042
tesseract::LSTMRecognizer::null_char_
int32_t null_char_
Definition: lstmrecognizer.h:281
UNICHARSET::has_special_codes
bool has_special_codes() const
Definition: unicharset.h:712
tesseract::Network::NumOutputs
int NumOutputs() const
Definition: network.h:123
tesseract::ET_CHAR_ERROR
Definition: lstmtrainer.h:41
callcpp.h
tprintf.h
tesseract::RecodedCharID
Definition: unicharcompress.h:34
tesseract::TRand::SignedRand
double SignedRand(double range)
Definition: helpers.h:85
tesseract::LSTMTrainer::best_iteration_
int best_iteration_
Definition: lstmtrainer.h:413
GenericVector< char >
tesseract::LSTMTrainer::InitTensorFlowNetwork
int InitTensorFlowNetwork(const std::string &tf_proto)
tesseract::LSTMTrainer::sub_trainer_
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:429
tesseract::LSTMTrainer::DeSerialize
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmtrainer.cpp:440
tesseract::kImprovementFraction
const double kImprovementFraction
Definition: lstmtrainer.cpp:67
tesseract::NetworkIO::SubtractAllFromFloat
void SubtractAllFromFloat(const NetworkIO &src)
Definition: networkio.cpp:824
tesseract::LT_SOFTMAX
Definition: static_shape.h:32
tesseract::Network
Definition: network.h:105
ScrollView::AwaitEvent
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:443
tesseract::HI_PRECISION_ERR
Definition: lstmtrainer.h:51
tesseract::ET_RMS
Definition: lstmtrainer.h:38
STRING::length
int32_t length() const
Definition: strngs.cpp:187
tesseract::kBestCheckpointFraction
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:69
tesseract::LSTMTrainer::EmptyConstructor
void EmptyConstructor()
Definition: lstmtrainer.cpp:990
tesseract::LSTMTrainer
Definition: lstmtrainer.h:79
imagedata.h
tesseract::NetworkIO::Resize
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
tesseract::Network::name
const STRING & name() const
Definition: network.h:138
tesseract::Network::ClearWindow
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
tesseract::Network::Update
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
Definition: network.h:230
GenericVector::truncate
void truncate(int size)
Definition: genericvector.h:132
tesseract::LSTMTrainer::ComputeCTCTargets
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:1099
tfnetwork.h
tesseract::LSTMTrainer::recon_win_
ScrollView * recon_win_
Definition: lstmtrainer.h:389
tesseract::LSTMTrainer::InitCharSet
void InitCharSet()
Definition: lstmtrainer.cpp:968
tesseract::TessdataManager::Init
bool Init(const char *data_file_name)
Definition: tessdatamanager.cpp:97
tesseract::CTC::NormalizeProbs
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
STRING::add_str_double
void add_str_double(const char *str, double number)
Definition: strngs.cpp:380
tesseract::LSTMTrainer::ReduceLearningRates
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
Definition: lstmtrainer.cpp:562
tesseract::CachingStrategy
CachingStrategy
Definition: imagedata.h:41
tesseract::TS_TEMP_DISABLE
Definition: network.h:97
tesseract::kMinStallIterations
const int kMinStallIterations
Definition: lstmtrainer.cpp:48
GenericVector::init_to_size
void init_to_size(int size, const T &t)
Definition: genericvector.h:706
tesseract::ImageData::language
const STRING & language() const
Definition: imagedata.h:140
tesseract::LSTMTrainer::last_perfect_training_iteration_
int last_perfect_training_iteration_
Definition: lstmtrainer.h:454
tesseract::NetworkBuilder::InitNetwork
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Definition: networkbuilder.cpp:45
tesseract::NetworkIO::AnySuspiciousTruth
bool AnySuspiciousTruth(float confidence_thr) const
Definition: networkio.cpp:579
tesseract::LSTMRecognizer::NumOutputs
int NumOutputs() const
Definition: lstmrecognizer.h:59
tesseract::LSTMTrainer::CurrentTrainingStage
int CurrentTrainingStage() const
Definition: lstmtrainer.h:197
tesseract::LSTMTrainer::ReduceLayerLearningRates
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
Definition: lstmtrainer.cpp:581
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::LSTMRecognizer::network_
Network * network_
Definition: lstmrecognizer.h:261
tesseract::LSTMTrainer::best_error_iterations_
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:437
tesstrain_utils.type
type
Definition: tesstrain_utils.py:141
tesseract::kErrorGraphInterval
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:57
tesseract::LSTMTrainer::UpdateErrorGraph
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
Definition: lstmtrainer.cpp:1260
tesseract::Network::spec
virtual STRING spec() const
Definition: network.h:141
tesseract::LSTMTrainer::worst_model_data_
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:424
tesseract::ET_SKIP_RATIO
Definition: lstmtrainer.h:42
tesseract::LIGHT
Definition: lstmtrainer.h:57
tesseract::NF_LAYER_SPECIFIC_LR
Definition: network.h:87
tesseract::LSTMTrainer::worst_error_rates_
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:417
tesseract::kHighConfidence
const double kHighConfidence
Definition: lstmtrainer.cpp:65
tesseract::LSTMTrainer::kRollingBufferSize_
static const int kRollingBufferSize_
Definition: lstmtrainer.h:457
tesseract::LSTMTrainer::error_rates_
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:460
tesseract::LSTMTrainer::~LSTMTrainer
virtual ~LSTMTrainer()
Definition: lstmtrainer.cpp:93
tesseract::DocumentCache::Clear
void Clear()
Definition: imagedata.h:326
tesseract::LSTMTrainer::MaintainCheckpoints
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
Definition: lstmtrainer.cpp:285
tesseract::LSTMTrainer::training_data_
DocumentCache training_data_
Definition: lstmtrainer.h:400
tesseract::LSTMTrainer::TransitionTrainingStage
bool TransitionTrainingStage(float error_threshold)
Definition: lstmtrainer.cpp:393
tesseract::UnicharCompress::EncodeUnichar
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
Definition: unicharcompress.cpp:283
tesseract::LSTMTrainer::PrepareLogMsg
void PrepareLogMsg(STRING *log_msg) const
Definition: lstmtrainer.cpp:372
tesseract::NetworkIO::NumFeatures
int NumFeatures() const
Definition: networkio.h:111
tesseract::TF_COMPRESS_UNICHARSET
Definition: lstmrecognizer.h:48
tesseract::UNENCODABLE
Definition: lstmtrainer.h:50
tesseract::TESSDATA_LSTM
Definition: tessdatamanager.h:74
tesseract::LSTMTrainer::ComputeTextTargets
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
Definition: lstmtrainer.cpp:1079
tesseract::UnicharCompress
Definition: unicharcompress.h:128
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::Network::num_weights
int num_weights() const
Definition: network.h:119
ScrollView::GREEN_YELLOW
Definition: scrollview.h:149
tesseract::LSTMRecognizer::adam_beta_
float adam_beta_
Definition: lstmrecognizer.h:286
tesseract::TFile::OpenWrite
void OpenWrite(GenericVector< char > *data)
Definition: serialis.cpp:309
tesseract::LSTMTrainer::LSTMTrainer
LSTMTrainer()
Definition: lstmtrainer.cpp:74
tesseract::LSTMTrainer::SaveTraineddata
bool SaveTraineddata(const STRING &filename)
Definition: lstmtrainer.cpp:895
tesseract::LSTMTrainer::InitIterations
void InitIterations()
Definition: lstmtrainer.cpp:190
tesseract::LSTMTrainer::checkpoint_iteration_
int checkpoint_iteration_
Definition: lstmtrainer.h:393
tesseract::NO_BEST_TRAINER
Definition: lstmtrainer.h:58
tesseract::Network::CountAlternators
virtual void CountAlternators(const Network &other, double *same, double *changed) const
Definition: network.h:235
tesseract::TRAINABLE
Definition: lstmtrainer.h:48
tesseract::TESSDATA_LSTM_UNICHARSET
Definition: tessdatamanager.h:78
tesseract::CTC::ComputeCTCTargets
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:54
tesseract::TestCallback
std::function< STRING(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:73
search
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:202
input.h
tesseract::ET_DELTA
Definition: lstmtrainer.h:39
tesseract::TessdataManager::SaveFile
bool SaveFile(const STRING &filename, FileWriter writer) const
Definition: tessdatamanager.cpp:153
UNICHARSET::size
int size() const
Definition: unicharset.h:341
STRING::split
void split(char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:275
tesseract::LSTMTrainer::ComputeWinnerError
double ComputeWinnerError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1149
tesseract::RecodeBeamSearch::ExtractBestPathAsLabels
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:192
tesseract::LSTMTrainer::LogIterations
void LogIterations(const char *intro_str, STRING *log_msg) const
Definition: lstmtrainer.cpp:384
tesseract::LSTMTrainer::ComputeCharError
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Definition: lstmtrainer.cpp:1167
tesseract::LSTMRecognizer::sample_iteration
int sample_iteration() const
Definition: lstmrecognizer.h:61
tesseract::NOT_BOXED
Definition: lstmtrainer.h:52
tesseract::LSTMRecognizer::DecodeLabels
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: lstmrecognizer.cpp:334