tesseract  5.0.0-alpha-619-ge9db
lstmtrainer.h
Go to the documentation of this file.
1 // File: lstmtrainer.h
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_
19 #define TESSERACT_LSTM_LSTMTRAINER_H_
20 
21 #include <functional> // for std::function
22 #include "imagedata.h"
23 #include "lstmrecognizer.h"
24 #include "rect.h"
25 
26 namespace tesseract {
27 
28 class LSTM;
29 class LSTMTester;
30 class LSTMTrainer;
31 class Parallel;
32 class Reversed;
33 class Softmax;
34 class Series;
35 
// Enum for the types of errors that are counted.
// The enumerators are used directly as indices into arrays sized by ET_COUNT
// (e.g. error_rates_[ET_DELTA]), so they are spelled out explicitly here and
// must remain contiguous from zero, with ET_COUNT kept last.
enum ErrorTypes {
  ET_RMS = 0,          // RMS activation error.
  ET_DELTA = 1,        // Number of big errors in deltas.
  ET_WORD_RECERR = 2,  // Output text string word recall error.
  ET_CHAR_ERROR = 3,   // Output text string total char error.
  ET_SKIP_RATIO = 4,   // Fraction of samples skipped.
  ET_COUNT             // For array sizing (one past the last error type).
};
45 
46 // Enum for the trainability_ flags.
48  TRAINABLE, // Non-zero delta error.
49  PERFECT, // Zero delta error.
50  UNENCODABLE, // Not trainable due to coding/alignment trouble.
51  HI_PRECISION_ERR, // Hi confidence disagreement.
52  NOT_BOXED, // Early in training and has no character boxes.
53 };
54 
55 // Enum to define the amount of data to get serialized.
57  LIGHT, // Minimal data for remote training.
58  NO_BEST_TRAINER, // Save an empty vector in place of best_trainer_.
59  FULL, // All data including best_trainer_.
60 };
61 
62 // Enum to indicate how the sub_trainer_ training went.
64  STR_NONE, // Did nothing as not good enough.
65  STR_UPDATED, // Subtrainer was updated, but didn't replace *this.
66  STR_REPLACED // Subtrainer replaced *this.
67 };
68 
69 class LSTMTrainer;
70 // Function to compute and record error rates on some external test set(s).
71 // Args are: iteration, mean errors, model, training stage.
72 // Returns a STRING containing logging information about the tests.
73 using TestCallback = std::function<STRING(int, const double*, const TessdataManager&, int)>;
74 
75 // Trainer class for LSTM networks. Most of the effort is in creating the
76 // ideal target outputs from the transcription. A box file is used if it is
77 // available, otherwise estimates of the char widths from the unicharset are
78 // used to guide a DP search for the best fit to the transcription.
79 class LSTMTrainer : public LSTMRecognizer {
80  public:
81  LSTMTrainer();
82  LSTMTrainer(const char* model_base, const char* checkpoint_name,
83  int debug_interval, int64_t max_memory);
84  virtual ~LSTMTrainer();
85 
86  // Tries to deserialize a trainer from the given file and silently returns
87  // false in case of failure. If old_traineddata is not null, then it is
88  // assumed that the character set is to be re-mapped from old_traineddata to
89  // the new, with consequent change in weight matrices etc.
90  bool TryLoadingCheckpoint(const char* filename, const char* old_traineddata);
91 
92  // Initializes the character set encode/decode mechanism directly from a
93  // previously setup traineddata containing dawgs, UNICHARSET and
94  // UnicharCompress. Note: Call before InitNetwork!
95  void InitCharSet(const std::string& traineddata_path) {
96  ASSERT_HOST(mgr_.Init(traineddata_path.c_str()));
97  InitCharSet();
98  }
99  void InitCharSet(const TessdataManager& mgr) {
100  mgr_ = mgr;
101  InitCharSet();
102  }
103 
104  // Initializes the trainer with a network_spec in the network description
105  // net_flags control network behavior according to the NetworkFlags enum.
106  // There isn't really much difference between them - only where the effects
107  // are implemented.
108  // For other args see NetworkBuilder::InitNetwork.
109  // Note: Be sure to call InitCharSet before InitNetwork!
110  bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
111  float weight_range, float learning_rate, float momentum,
112  float adam_beta);
113  // Initializes a trainer from a serialized TFNetworkModel proto.
114  // Returns the global step of TensorFlow graph or 0 if failed.
115  // Building a compatible TF graph: See tfnetwork.proto.
116  int InitTensorFlowNetwork(const std::string& tf_proto);
117  // Resets all the iteration counters for fine tuning or training a head,
118  // where we want the error reporting to reset.
119  void InitIterations();
120 
121  // Accessors.
122  double ActivationError() const {
123  return error_rates_[ET_DELTA];
124  }
125  double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
126  const double* error_rates() const {
127  return error_rates_;
128  }
129  double best_error_rate() const {
130  return best_error_rate_;
131  }
132  int best_iteration() const {
133  return best_iteration_;
134  }
135  int learning_iteration() const { return learning_iteration_; }
136  int32_t improvement_steps() const { return improvement_steps_; }
137  void set_perfect_delay(int delay) { perfect_delay_ = delay; }
138  const GenericVector<char>& best_trainer() const { return best_trainer_; }
139  // Returns the error that was just calculated by PrepareForBackward.
142  }
143  // Returns the error that was just calculated by TrainOnLine. Since
144  // TrainOnLine rolls the error buffers, this is one further back than
145  // NewSingleError.
147  return error_buffers_[type]
150  }
151  const DocumentCache& training_data() const {
152  return training_data_;
153  }
155 
156  // If the training sample is usable, grid searches for the optimal
157  // dict_ratio/cert_offset, and returns the results in a string of space-
158  // separated triplets of ratio,offset=worderr.
160  const ImageData* trainingdata, int iteration, double min_dict_ratio,
161  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
162  double cert_offset_step, double max_cert_offset, STRING* results);
163 
164  // Provides output on the distribution of weight values.
165  void DebugNetwork();
166 
167  // Loads a set of lstmf files that were created using the lstm.train config to
168  // tesseract into memory ready for training. Returns false if nothing was
169  // loaded.
170  bool LoadAllTrainingData(const GenericVector<STRING>& filenames,
171  CachingStrategy cache_strategy,
172  bool randomly_rotate);
173 
174  // Keeps track of best and locally worst error rate, using internally computed
175  // values. See MaintainCheckpointsSpecific for more detail.
176  bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
177  // Keeps track of best and locally worst error_rate (whatever it is) and
178  // launches tests using rec_model, when a new min or max is reached.
179  // Writes checkpoints using train_model at appropriate times and builds and
180  // returns a log message to indicate progress. Returns false if nothing
181  // interesting happened.
182  bool MaintainCheckpointsSpecific(int iteration,
183  const GenericVector<char>* train_model,
184  const GenericVector<char>* rec_model,
185  TestCallback tester, STRING* log_msg);
186  // Builds a string containing a progress message with current error rates.
187  void PrepareLogMsg(STRING* log_msg) const;
188  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
189  // sample_iteration() to the log_msg.
190  void LogIterations(const char* intro_str, STRING* log_msg) const;
191 
192  // TODO(rays) Add curriculum learning.
193  // Returns true and increments the training_stage_ if the error rate has just
194  // passed through the given threshold for the first time.
195  bool TransitionTrainingStage(float error_threshold);
196  // Returns the current training stage.
197  int CurrentTrainingStage() const { return training_stage_; }
198 
199  // Writes to the given file. Returns false in case of error.
200  bool Serialize(SerializeAmount serialize_amount,
201  const TessdataManager* mgr, TFile* fp) const;
202  // Reads from the given file. Returns false in case of error.
203  bool DeSerialize(const TessdataManager* mgr, TFile* fp);
204 
205  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
206  // learning rates (by scaling reduction, or layer specific, according to
207  // NF_LAYER_SPECIFIC_LR).
208  void StartSubtrainer(STRING* log_msg);
209  // While the sub_trainer_ is behind the current training iteration and its
210  // training error is at least kSubTrainerMarginFraction better than the
211  // current training error, trains the sub_trainer_, and returns STR_UPDATED if
212  // it did anything. If it catches up, and has a better error rate than the
213  // current best, as well as a margin over the current error rate, then the
214  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
215  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
216  // receive any training iterations.
218  // Reduces network learning rates, either for everything, or for layers
219  // independently, according to NF_LAYER_SPECIFIC_LR.
220  void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
221  // Considers reducing the learning rate independently for each layer down by
222  // factor(<1), or leaving it the same, by double-training the given number of
223  // samples and minimizing the amount of changing of sign of weight updates.
224  // Even if it looks like all weights should remain the same, an adjustment
225  // will be made to guarantee a different result when reverting to an old best.
226  // Returns the number of layer learning rates that were reduced.
227  int ReduceLayerLearningRates(double factor, int num_samples,
228  LSTMTrainer* samples_trainer);
229 
230  // Converts the string to integer class labels, with appropriate null_char_s
231  // in between if not in SimpleTextOutput mode. Returns false on failure.
232  bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
233  return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : nullptr,
234  SimpleTextOutput(), null_char_, labels);
235  }
236  // Static version operates on supplied unicharset, encoder, simple_text.
237  static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
238  const UnicharCompress* recoder, bool simple_text,
239  int null_char, GenericVector<int>* labels);
240 
241  // Performs forward-backward on the given trainingdata.
242  // Returns the sample that was used or nullptr if the next sample was deemed
243  // unusable. samples_trainer could be this or an alternative trainer that
244  // holds the training samples.
245  const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
246  int sample_index = sample_iteration();
247  const ImageData* image =
248  samples_trainer->training_data_.GetPageBySerial(sample_index);
249  if (image != nullptr) {
250  Trainability trainable = TrainOnLine(image, batch);
251  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
252  return nullptr; // Sample was unusable.
253  }
254  } else {
256  }
257  return image;
258  }
259  Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
260 
261  // Prepares the ground truth, runs forward, and prepares the targets.
262  // Returns a Trainability enum to indicate the suitability of the sample.
263  Trainability PrepareForBackward(const ImageData* trainingdata,
264  NetworkIO* fwd_outputs, NetworkIO* targets);
265 
266  // Writes the trainer to memory, so that the current training state can be
267  // restored. *this must always be the master trainer that retains the only
268  // copy of the training data and language model. trainer is the model that is
269  // actually serialized.
270  bool SaveTrainingDump(SerializeAmount serialize_amount,
271  const LSTMTrainer* trainer,
272  GenericVector<char>* data) const;
273 
274  // Reads previously saved trainer from memory. *this must always be the
275  // master trainer that retains the only copy of the training data and
276  // language model. trainer is the model that is restored.
278  LSTMTrainer* trainer) const {
279  if (data.empty()) return false;
280  return ReadSizedTrainingDump(&data[0], data.size(), trainer);
281  }
282  bool ReadSizedTrainingDump(const char* data, int size,
283  LSTMTrainer* trainer) const {
284  return trainer->ReadLocalTrainingDump(&mgr_, data, size);
285  }
286  // Restores the model to *this.
287  bool ReadLocalTrainingDump(const TessdataManager* mgr, const char* data,
288  int size);
289 
290  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
291  void SetupCheckpointInfo();
292 
293  // Writes the full recognition traineddata to the given filename.
294  bool SaveTraineddata(const STRING& filename);
295 
296  // Writes the recognizer to memory, so that it can be used for testing later.
297  void SaveRecognitionDump(GenericVector<char>* data) const;
298 
299  // Returns a suitable filename for a training dump, based on the model_base_,
300  // the iteration and the error rates.
301  STRING DumpFilename() const;
302 
303  // Fills the whole error buffer of the given type with the given value.
304  void FillErrorBuffer(double new_error, ErrorTypes type);
305  // Helper generates a map from each current recoder_ code (ie softmax index)
306  // to the corresponding old_recoder code, or -1 if there isn't one.
307  std::vector<int> MapRecoder(const UNICHARSET& old_chset,
308  const UnicharCompress& old_recoder) const;
309 
310  protected:
311  // Private version of InitCharSet above finishes the job after initializing
312  // the mgr_ data member.
313  void InitCharSet();
314  // Helper computes and sets the null_char_.
315  void SetNullChar();
316 
317  // Factored sub-constructor sets up reasonable default values.
318  void EmptyConstructor();
319 
320  // Outputs the string and periodically displays the given network inputs
321  // as an image in the given window, and the corresponding labels at the
322  // corresponding x_starts.
323  // Returns false if the truth string is empty.
324  bool DebugLSTMTraining(const NetworkIO& inputs,
325  const ImageData& trainingdata,
326  const NetworkIO& fwd_outputs,
327  const GenericVector<int>& truth_labels,
328  const NetworkIO& outputs);
329  // Displays the network targets as a line graph.
330  void DisplayTargets(const NetworkIO& targets, const char* window_name,
331  ScrollView** window);
332 
333  // Builds a no-compromises target where the first positions should be the
334  // truth labels and the rest is padded with the null_char_.
335  bool ComputeTextTargets(const NetworkIO& outputs,
336  const GenericVector<int>& truth_labels,
337  NetworkIO* targets);
338 
339  // Builds a target using standard CTC. truth_labels should be pre-padded with
340  // nulls wherever desired. They don't have to be between all labels.
341  // outputs is input-output, as it gets clipped to minimum probability.
342  bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
343  NetworkIO* outputs, NetworkIO* targets);
344 
345  // Computes network errors, and stores the results in the rolling buffers,
346  // along with the supplied text_error.
347  // Returns the delta error of the current sample (not running average.)
348  double ComputeErrorRates(const NetworkIO& deltas, double char_error,
349  double word_error);
350 
351  // Computes the network activation RMS error rate.
352  double ComputeRMSError(const NetworkIO& deltas);
353 
354  // Computes network activation winner error rate. (Number of values that are
355  // in error by >= 0.5 divided by number of time-steps.) More closely related
356  // to final character error than RMS, but still directly calculable from
357  // just the deltas. Because of the binary nature of the targets, zero winner
358  // error is a sufficient but not necessary condition for zero char error.
359  double ComputeWinnerError(const NetworkIO& deltas);
360 
361  // Computes a very simple bag of chars char error rate.
362  double ComputeCharError(const GenericVector<int>& truth_str,
363  const GenericVector<int>& ocr_str);
364  // Computes a very simple bag of words word recall error rate.
365  // NOTE that this is destructive on both input strings.
366  double ComputeWordError(STRING* truth_str, STRING* ocr_str);
367 
368  // Updates the error buffer and corresponding mean of the given type with
369  // the new_error.
370  void UpdateErrorBuffer(double new_error, ErrorTypes type);
371 
372  // Rolls error buffers and reports the current means.
373  void RollErrorBuffers();
374 
375  // Given that error_rate is either a new min or max, updates the best/worst
376  // error rates, and record of progress.
377  STRING UpdateErrorGraph(int iteration, double error_rate,
378  const GenericVector<char>& model_data,
379  TestCallback tester);
380 
381  protected:
382  // Alignment display window.
384  // CTC target display window.
386  // CTC output display window.
388  // Reconstructed image window.
390  // How often to display a debug image.
392  // Iteration at which the last checkpoint was dumped.
394  // Basename of files to save best models to.
396  // Checkpoint filename.
398  // Training data.
401  // Name to use when saving best_trainer_.
403  // Number of available training stages.
405 
406  // ===Serialized data to ensure that a restart produces the same results.===
407  // These members are only serialized when serialize_amount != LIGHT.
408  // Best error rate so far.
410  // Snapshot of all error rates at best_iteration_.
412  // Iteration of best_error_rate_.
414  // Worst error rate since best_error_rate_.
416  // Snapshot of all error rates at worst_iteration_.
418  // Iteration of worst_error_rate_.
420  // Iteration at which the process will be thought stalled.
422  // Saved recognition models for computing test error for graph points.
425  // Saved trainer for reverting back to last known best.
427  // A subsidiary trainer running with a different learning rate until either
428  // *this or sub_trainer_ hits a new best.
430  // Error rate at which last best model was dumped.
432  // Current stage of training.
434  // History of best error rate against iteration. Used for computing the
435  // number of steps to each 2% improvement.
438  // Number of iterations since the best_error_rate_ was 2% more than it is now.
440  // Number of iterations that yielded a non-zero delta error and thus provided
441  // significant learning. learning_iteration_ <= training_iteration_.
442  // learning_iteration_ is used to measure rate of learning progress.
444  // Saved value of sample_iteration_ before looking for the next sample.
446  // How often to include a PERFECT training sample in backprop.
447  // A PERFECT training sample is used if the current
448  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
449  // so with perfect_delay_ == 0, all samples are used, and with
450  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
452  // Value of training_iteration_ at which the last PERFECT training sample
453  // was used in back prop.
455  // Rolling buffers storing recent training errors are indexed by
456  // training_iteration % kRollingBufferSize_.
457  static const int kRollingBufferSize_ = 1000;
459  // Rounded mean percent trailing training errors in the buffers.
460  double error_rates_[ET_COUNT]; // RMS training error.
461  // Traineddata file with optional dawgs + UNICHARSET and recoder.
463 };
464 
465 } // namespace tesseract.
466 
467 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
tesseract::ET_COUNT
Definition: lstmtrainer.h:43
tesseract::LSTMTrainer::InitNetwork
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
Definition: lstmtrainer.cpp:146
string
std::string string
Definition: equationdetect_test.cc:21
tesseract::LSTMTrainer::DebugLSTMTraining
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
Definition: lstmtrainer.cpp:1005
tesseract::LSTMTrainer::LoadAllTrainingData
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
Definition: lstmtrainer.cpp:272
ScrollView
Definition: scrollview.h:97
tesseract::STR_NONE
Definition: lstmtrainer.h:64
tesseract::LSTMTrainer::randomly_rotate_
bool randomly_rotate_
Definition: lstmtrainer.h:399
tesseract::LSTMTrainer::InitCharSet
void InitCharSet(const TessdataManager &mgr)
Definition: lstmtrainer.h:99
tesseract::STR_REPLACED
Definition: lstmtrainer.h:66
tesseract::LSTMTrainer::ActivationError
double ActivationError() const
Definition: lstmtrainer.h:122
tesseract::LSTMTrainer::learning_iteration_
int learning_iteration_
Definition: lstmtrainer.h:443
tesseract::LSTMTrainer::PrepareForBackward
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:770
tesseract::LSTMTrainer::UpdateSubtrainer
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:519
tesseract::LSTMTrainer::SaveTrainingDump
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
Definition: lstmtrainer.cpp:874
tesseract::LSTMTrainer::SaveRecognitionDump
void SaveRecognitionDump(GenericVector< char > *data) const
Definition: lstmtrainer.cpp:904
tesseract::SerializeAmount
SerializeAmount
Definition: lstmtrainer.h:56
tesseract::TessdataManager
Definition: tessdatamanager.h:126
tesseract::LSTMTrainer::mutable_training_data
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:154
tesseract::LSTMTrainer::mgr_
TessdataManager mgr_
Definition: lstmtrainer.h:462
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::LSTMTrainer::ReadSizedTrainingDump
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:282
tesseract::LSTMTrainer::best_error_rates_
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:411
tesseract::LSTMTrainer::num_training_stages_
int num_training_stages_
Definition: lstmtrainer.h:404
tesseract::LSTMTrainer::MapRecoder
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
Definition: lstmtrainer.cpp:933
tesseract::LSTMTrainer::prev_sample_iteration_
int prev_sample_iteration_
Definition: lstmtrainer.h:445
tesseract::LSTMTrainer::best_iteration
int best_iteration() const
Definition: lstmtrainer.h:132
tesseract::LSTMRecognizer::training_iteration
int training_iteration() const
Definition: lstmrecognizer.h:60
tesseract::LSTMTrainer::worst_error_rate_
double worst_error_rate_
Definition: lstmtrainer.h:415
tesseract::LSTMTrainer::checkpoint_name_
STRING checkpoint_name_
Definition: lstmtrainer.h:397
tesseract::LSTMTrainer::SetupCheckpointInfo
void SetupCheckpointInfo()
tesseract::LSTMTrainer::best_error_rate_
double best_error_rate_
Definition: lstmtrainer.h:409
tesseract::LSTMTrainer::DumpFilename
STRING DumpFilename() const
Definition: lstmtrainer.cpp:914
tesseract::LSTMTrainer::debug_interval_
int debug_interval_
Definition: lstmtrainer.h:391
tesseract::LSTMRecognizer::learning_rate
double learning_rate() const
Definition: lstmrecognizer.h:62
tesseract::LSTMTrainer::perfect_delay_
int perfect_delay_
Definition: lstmtrainer.h:451
STRING
Definition: strngs.h:45
tesseract::DocumentCache
Definition: imagedata.h:320
tesseract::LSTMTrainer::error_buffers_
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:458
tesseract::LSTMRecognizer::sample_iteration_
int32_t sample_iteration_
Definition: lstmrecognizer.h:278
tesseract::LSTMTrainer::EncodeString
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:232
tesseract::LSTMTrainer::LastSingleError
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:146
rect.h
tesseract::LSTMTrainer::DebugNetwork
void DebugNetwork()
Definition: lstmtrainer.cpp:265
tesseract::LSTMRecognizer::GetUnicharset
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:132
tesseract::ImageData
Definition: imagedata.h:104
tesseract::LSTMTrainer::model_base_
STRING model_base_
Definition: lstmtrainer.h:395
tesseract::LSTMTrainer::TrainOnLine
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:245
tesseract::LSTMTrainer::best_error_history_
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:436
tesseract::ET_WORD_RECERR
Definition: lstmtrainer.h:40
tesseract::PERFECT
Definition: lstmtrainer.h:49
tesseract::LSTMTrainer::best_trainer_
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:426
tesseract::LSTMTrainer::FillErrorBuffer
void FillErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:925
tesseract::LSTMTrainer::StartSubtrainer
void StartSubtrainer(STRING *log_msg)
Definition: lstmtrainer.cpp:489
tesseract::ErrorTypes
ErrorTypes
Definition: lstmtrainer.h:37
tesseract::FULL
Definition: lstmtrainer.h:59
tesseract::LSTMTrainer::error_rate_of_last_saved_best_
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:431
tesseract::LSTMTrainer::Serialize
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
Definition: lstmtrainer.cpp:403
tesseract::LSTMTrainer::NewSingleError
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:140
tesseract::DocumentCache::GetPageBySerial
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:343
tesseract::LSTMTrainer::RollErrorBuffers
void RollErrorBuffers()
Definition: lstmtrainer.cpp:1241
tesseract::LSTMTrainer::align_win_
ScrollView * align_win_
Definition: lstmtrainer.h:383
tesseract::LSTMTrainer::ReadLocalTrainingDump
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
Definition: lstmtrainer.cpp:883
tesseract::LSTMRecognizer::null_char
int null_char() const
Definition: lstmrecognizer.h:145
tesseract::LSTMTrainer::CharError
double CharError() const
Definition: lstmtrainer.h:125
tesseract::LSTMTrainer::best_model_data_
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:423
tesseract::LSTMTrainer::UpdateErrorBuffer
void UpdateErrorBuffer(double new_error, ErrorTypes type)
Definition: lstmtrainer.cpp:1228
tesseract::LSTMTrainer::SetNullChar
void SetNullChar()
Definition: lstmtrainer.cpp:981
tesseract::LSTMTrainer::learning_iteration
int learning_iteration() const
Definition: lstmtrainer.h:135
tesseract::LSTMTrainer::ComputeRMSError
double ComputeRMSError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1130
tesseract::TFile
Definition: serialis.h:75
tesseract::LSTMTrainer::ComputeErrorRates
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
Definition: lstmtrainer.cpp:1110
tesseract::STR_UPDATED
Definition: lstmtrainer.h:65
GenericVector::empty
bool empty() const
Definition: genericvector.h:86
UNICHARSET
Definition: unicharset.h:145
tesseract::LSTMTrainer::improvement_steps_
int32_t improvement_steps_
Definition: lstmtrainer.h:439
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::LSTMTrainer::best_model_name_
STRING best_model_name_
Definition: lstmtrainer.h:402
tesseract::LSTMTrainer::training_data
const DocumentCache & training_data() const
Definition: lstmtrainer.h:151
tesseract::LSTMTrainer::ctc_win_
ScrollView * ctc_win_
Definition: lstmtrainer.h:387
tesseract::LSTMTrainer::improvement_steps
int32_t improvement_steps() const
Definition: lstmtrainer.h:136
tesseract::LSTMRecognizer::recoder_
UnicharCompress recoder_
Definition: lstmrecognizer.h:268
lstmrecognizer.h
tesseract::Trainability
Trainability
Definition: lstmtrainer.h:47
tesseract::LSTMTrainer::worst_iteration_
int worst_iteration_
Definition: lstmtrainer.h:419
tesseract
Definition: baseapi.h:65
tesseract::LSTMTrainer::TryLoadingCheckpoint
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
Definition: lstmtrainer.cpp:103
tesseract::LSTMTrainer::ComputeWordError
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
Definition: lstmtrainer.cpp:1195
tesseract::LSTMTrainer::set_perfect_delay
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:137
tesseract::LSTMTrainer::target_win_
ScrollView * target_win_
Definition: lstmtrainer.h:385
tesseract::LSTMTrainer::stall_iteration_
int stall_iteration_
Definition: lstmtrainer.h:421
tesseract::LSTMTrainer::ReadTrainingDump
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Definition: lstmtrainer.h:277
tesseract::LSTMTrainer::training_stage_
int training_stage_
Definition: lstmtrainer.h:433
tesseract::SubTrainerResult
SubTrainerResult
Definition: lstmtrainer.h:63
tesseract::LSTMTrainer::GridSearchDictParams
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
Definition: lstmtrainer.cpp:215
tesseract::LSTMRecognizer::SimpleTextOutput
bool SimpleTextOutput() const
Definition: lstmrecognizer.h:69
tesseract::LSTMTrainer::DisplayTargets
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
Definition: lstmtrainer.cpp:1042
tesseract::LSTMRecognizer::null_char_
int32_t null_char_
Definition: lstmrecognizer.h:281
tesseract::ET_CHAR_ERROR
Definition: lstmtrainer.h:41
tesseract::LSTMTrainer::best_iteration_
int best_iteration_
Definition: lstmtrainer.h:413
tesseract::LSTMTrainer::error_rates
const double * error_rates() const
Definition: lstmtrainer.h:126
GenericVector< char >
tesseract::LSTMTrainer::InitTensorFlowNetwork
int InitTensorFlowNetwork(const std::string &tf_proto)
tesseract::LSTMTrainer::sub_trainer_
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:429
tesseract::LSTMTrainer::DeSerialize
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmtrainer.cpp:440
tesseract::HI_PRECISION_ERR
Definition: lstmtrainer.h:51
tesseract::ET_RMS
Definition: lstmtrainer.h:38
tesseract::LSTMTrainer::EmptyConstructor
void EmptyConstructor()
Definition: lstmtrainer.cpp:990
tesseract::LSTMTrainer::InitCharSet
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:95
tesseract::LSTMTrainer
Definition: lstmtrainer.h:79
imagedata.h
tesseract::LSTMTrainer::best_error_rate
double best_error_rate() const
Definition: lstmtrainer.h:129
tesseract::LSTMTrainer::ComputeCTCTargets
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
Definition: lstmtrainer.cpp:1099
tesseract::LSTMTrainer::recon_win_
ScrollView * recon_win_
Definition: lstmtrainer.h:389
tesseract::LSTMRecognizer
Definition: lstmrecognizer.h:53
tesseract::LSTMTrainer::InitCharSet
void InitCharSet()
Definition: lstmtrainer.cpp:968
tesseract::TessdataManager::Init
bool Init(const char *data_file_name)
Definition: tessdatamanager.cpp:97
tesseract::LSTMRecognizer::IsRecoding
bool IsRecoding() const
Definition: lstmrecognizer.h:72
tesseract::LSTMTrainer::ReduceLearningRates
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
Definition: lstmtrainer.cpp:562
tesseract::CachingStrategy
CachingStrategy
Definition: imagedata.h:41
tesseract::LSTMTrainer::last_perfect_training_iteration_
int last_perfect_training_iteration_
Definition: lstmtrainer.h:454
tesseract::LSTMTrainer::best_trainer
const GenericVector< char > & best_trainer() const
Definition: lstmtrainer.h:138
tesseract::LSTMTrainer::CurrentTrainingStage
int CurrentTrainingStage() const
Definition: lstmtrainer.h:197
tesseract::LSTMTrainer::ReduceLayerLearningRates
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
Definition: lstmtrainer.cpp:581
tesseract::LSTMTrainer::best_error_iterations_
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:437
tesstrain_utils.type
type
Definition: tesstrain_utils.py:141
tesseract::LSTMTrainer::UpdateErrorGraph
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
Definition: lstmtrainer.cpp:1260
tesseract::LSTMTrainer::worst_model_data_
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:424
tesseract::ET_SKIP_RATIO
Definition: lstmtrainer.h:42
tesseract::LIGHT
Definition: lstmtrainer.h:57
tesseract::LSTMTrainer::worst_error_rates_
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:417
tesseract::LSTMTrainer::kRollingBufferSize_
static const int kRollingBufferSize_
Definition: lstmtrainer.h:457
tesseract::LSTMTrainer::error_rates_
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:460
tesseract::LSTMTrainer::~LSTMTrainer
virtual ~LSTMTrainer()
Definition: lstmtrainer.cpp:93
tesseract::LSTMTrainer::MaintainCheckpoints
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
Definition: lstmtrainer.cpp:285
tesseract::LSTMTrainer::training_data_
DocumentCache training_data_
Definition: lstmtrainer.h:400
tesseract::LSTMTrainer::TransitionTrainingStage
bool TransitionTrainingStage(float error_threshold)
Definition: lstmtrainer.cpp:393
tesseract::LSTMTrainer::MaintainCheckpointsSpecific
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
tesseract::LSTMTrainer::PrepareLogMsg
void PrepareLogMsg(STRING *log_msg) const
Definition: lstmtrainer.cpp:372
tesseract::UNENCODABLE
Definition: lstmtrainer.h:50
tesseract::LSTMTrainer::ComputeTextTargets
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
Definition: lstmtrainer.cpp:1079
tesseract::UnicharCompress
Definition: unicharcompress.h:128
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::LSTMTrainer::LSTMTrainer
LSTMTrainer()
Definition: lstmtrainer.cpp:74
tesseract::LSTMTrainer::SaveTraineddata
bool SaveTraineddata(const STRING &filename)
Definition: lstmtrainer.cpp:895
tesseract::LSTMTrainer::InitIterations
void InitIterations()
Definition: lstmtrainer.cpp:190
tesseract::LSTMTrainer::checkpoint_iteration_
int checkpoint_iteration_
Definition: lstmtrainer.h:393
tesseract::NO_BEST_TRAINER
Definition: lstmtrainer.h:58
tesseract::TRAINABLE
Definition: lstmtrainer.h:48
tesseract::TestCallback
std::function< STRING(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:73
tesseract::ET_DELTA
Definition: lstmtrainer.h:39
tesseract::LSTMTrainer::ComputeWinnerError
double ComputeWinnerError(const NetworkIO &deltas)
Definition: lstmtrainer.cpp:1149
tesseract::LSTMTrainer::LogIterations
void LogIterations(const char *intro_str, STRING *log_msg) const
Definition: lstmtrainer.cpp:384
tesseract::LSTMTrainer::ComputeCharError
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Definition: lstmtrainer.cpp:1167
tesseract::LSTMRecognizer::sample_iteration
int sample_iteration() const
Definition: lstmrecognizer.h:61
tesseract::NOT_BOXED
Definition: lstmtrainer.h:52