tesseract  5.0.0-alpha-619-ge9db
lstmtraining.cpp
Go to the documentation of this file.
1 // File: lstmtraining.cpp
3 // Description: Training program for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifdef GOOGLE_TESSERACT
19 #include "base/commandlineflags.h"
20 #endif
21 #include <cerrno>
22 #include "commontraining.h"
23 #include "fileio.h" // for LoadFileLinesToStrings
24 #include "lstmtester.h"
25 #include "lstmtrainer.h"
26 #include "params.h"
27 #include <tesseract/strngs.h>
28 #include "tprintf.h"
30 
31 static INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.");
32 static STRING_PARAM_FLAG(net_spec, "", "Network specification");
33 static INT_PARAM_FLAG(net_mode, 192, "Controls network behavior.");
34 static INT_PARAM_FLAG(perfect_sample_delay, 0,
35  "How many imperfect samples between perfect ones.");
36 static DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.");
37 static DOUBLE_PARAM_FLAG(weight_range, 0.1, "Range of initial random weights.");
38 static DOUBLE_PARAM_FLAG(learning_rate, 10.0e-4, "Weight factor for new deltas.");
39 static DOUBLE_PARAM_FLAG(momentum, 0.5, "Decay factor for repeating deltas.");
40 static DOUBLE_PARAM_FLAG(adam_beta, 0.999, "Decay factor for repeating deltas.");
41 static INT_PARAM_FLAG(max_image_MB, 6000, "Max memory to use for images.");
42 static STRING_PARAM_FLAG(continue_from, "", "Existing model to extend");
43 static STRING_PARAM_FLAG(model_output, "lstmtrain", "Basename for output models");
44 static STRING_PARAM_FLAG(train_listfile, "",
45  "File listing training files in lstmf training format.");
46 static STRING_PARAM_FLAG(eval_listfile, "",
47  "File listing eval files in lstmf training format.");
48 static BOOL_PARAM_FLAG(stop_training, false,
49  "Just convert the training model to a runtime model.");
50 static BOOL_PARAM_FLAG(convert_to_int, false,
51  "Convert the recognition model to an integer model.");
52 static BOOL_PARAM_FLAG(sequential_training, false,
53  "Use the training files sequentially instead of round-robin.");
54 static INT_PARAM_FLAG(append_index, -1, "Index in continue_from Network at which to"
55  " attach the new network defined by net_spec");
56 static BOOL_PARAM_FLAG(debug_network, false,
57  "Get info on distribution of weight values");
58 static INT_PARAM_FLAG(max_iterations, 0, "If set, exit after this many iterations");
59 static STRING_PARAM_FLAG(traineddata, "",
60  "Combined Dawgs/Unicharset/Recoder for language model");
61 static STRING_PARAM_FLAG(old_traineddata, "",
62  "When changing the character set, this specifies the old"
63  " character set that is to be replaced");
64 static BOOL_PARAM_FLAG(randomly_rotate, false,
65  "Train OSD and randomly turn training samples upside-down");
66 
// Number of training images to train between calls to MaintainCheckpoints.
// main() runs at most this many training iterations per batch before it
// checkpoints and logs progress.
const int kNumPagesPerBatch = 100;
69 
70 // Apart from command-line flags, input is a collection of lstmf files, that
71 // were previously created using tesseract with the lstm.train config file.
72 // The program iterates over the inputs, feeding the data to the network,
73 // until the error rate reaches a specified target or max_iterations is reached.
74 int main(int argc, char **argv) {
75  tesseract::CheckSharedLibraryVersion();
76  ParseArguments(&argc, &argv);
77  if (FLAGS_model_output.empty()) {
78  tprintf("Must provide a --model_output!\n");
79  return EXIT_FAILURE;
80  }
81  if (FLAGS_traineddata.empty()) {
82  tprintf("Must provide a --traineddata see training wiki\n");
83  return EXIT_FAILURE;
84  }
85 
86  // Check write permissions.
87  STRING test_file = FLAGS_model_output.c_str();
88  test_file += "_wtest";
89  FILE* f = fopen(test_file.c_str(), "wb");
90  if (f != nullptr) {
91  fclose(f);
92  if (remove(test_file.c_str()) != 0) {
93  tprintf("Error, failed to remove %s: %s\n",
94  test_file.c_str(), strerror(errno));
95  return EXIT_FAILURE;
96  }
97  } else {
98  tprintf("Error, model output cannot be written: %s\n", strerror(errno));
99  return EXIT_FAILURE;
100  }
101 
102  // Setup the trainer.
103  STRING checkpoint_file = FLAGS_model_output.c_str();
104  checkpoint_file += "_checkpoint";
105  STRING checkpoint_bak = checkpoint_file + ".bak";
106  tesseract::LSTMTrainer trainer(
107  FLAGS_model_output.c_str(),
108  checkpoint_file.c_str(), FLAGS_debug_interval,
109  static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
110  trainer.InitCharSet(FLAGS_traineddata.c_str());
111 
112  // Reading something from an existing model doesn't require many flags,
113  // so do it now and exit.
114  if (FLAGS_stop_training || FLAGS_debug_network) {
115  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(), nullptr)) {
116  tprintf("Failed to read continue from: %s\n",
117  FLAGS_continue_from.c_str());
118  return EXIT_FAILURE;
119  }
120  if (FLAGS_debug_network) {
121  trainer.DebugNetwork();
122  } else {
123  if (FLAGS_convert_to_int) trainer.ConvertToInt();
124  if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
125  tprintf("Failed to write recognition model : %s\n",
126  FLAGS_model_output.c_str());
127  }
128  }
129  return EXIT_SUCCESS;
130  }
131 
132  // Get the list of files to process.
133  if (FLAGS_train_listfile.empty()) {
134  tprintf("Must supply a list of training filenames! --train_listfile\n");
135  return EXIT_FAILURE;
136  }
137  GenericVector<STRING> filenames;
138  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
139  &filenames)) {
140  tprintf("Failed to load list of training filenames from %s\n",
141  FLAGS_train_listfile.c_str());
142  return EXIT_FAILURE;
143  }
144 
145  // Checkpoints always take priority if they are available.
146  if (trainer.TryLoadingCheckpoint(checkpoint_file.c_str(), nullptr) ||
147  trainer.TryLoadingCheckpoint(checkpoint_bak.c_str(), nullptr)) {
148  tprintf("Successfully restored trainer from %s\n",
149  checkpoint_file.c_str());
150  } else {
151  if (!FLAGS_continue_from.empty()) {
152  // Load a past model file to improve upon.
153  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
154  FLAGS_append_index >= 0
155  ? FLAGS_continue_from.c_str()
156  : FLAGS_old_traineddata.c_str())) {
157  tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
158  return EXIT_FAILURE;
159  }
160  tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
161  trainer.InitIterations();
162  }
163  if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
164  if (FLAGS_append_index >= 0) {
165  tprintf("Appending a new network to an old one!!");
166  if (FLAGS_continue_from.empty()) {
167  tprintf("Must set --continue_from for appending!\n");
168  return EXIT_FAILURE;
169  }
170  }
171  // We are initializing from scratch.
172  if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
173  FLAGS_net_mode, FLAGS_weight_range,
174  FLAGS_learning_rate, FLAGS_momentum,
175  FLAGS_adam_beta)) {
176  tprintf("Failed to create network from spec: %s\n",
177  FLAGS_net_spec.c_str());
178  return EXIT_FAILURE;
179  }
180  trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
181  }
182  }
183  if (!trainer.LoadAllTrainingData(filenames,
184  FLAGS_sequential_training
187  FLAGS_randomly_rotate)) {
188  tprintf("Load of images failed!!\n");
189  return EXIT_FAILURE;
190  }
191 
192  tesseract::LSTMTester tester(static_cast<int64_t>(FLAGS_max_image_MB) *
193  1048576);
194  tesseract::TestCallback tester_callback = nullptr;
195  if (!FLAGS_eval_listfile.empty()) {
196  using namespace std::placeholders; // for _1, _2, _3...
197  if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
198  tprintf("Failed to load eval data from: %s\n",
199  FLAGS_eval_listfile.c_str());
200  return EXIT_FAILURE;
201  }
202  tester_callback = std::bind(&tesseract::LSTMTester::RunEvalAsync, &tester, _1, _2, _3, _4);
203  }
204  do {
205  // Train a few.
206  int iteration = trainer.training_iteration();
207  for (int target_iteration = iteration + kNumPagesPerBatch;
208  iteration < target_iteration &&
209  (iteration < FLAGS_max_iterations || FLAGS_max_iterations == 0);
210  iteration = trainer.training_iteration()) {
211  trainer.TrainOnLine(&trainer, false);
212  }
213  STRING log_str;
214  trainer.MaintainCheckpoints(tester_callback, &log_str);
215  tprintf("%s\n", log_str.c_str());
216  } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
217  (trainer.training_iteration() < FLAGS_max_iterations ||
218  FLAGS_max_iterations == 0));
219  tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
220  return EXIT_SUCCESS;
221 } /* main */
tesseract::LSTMTester
Definition: lstmtester.h:28
strngs.h
BOOL_PARAM_FLAG
#define BOOL_PARAM_FLAG(name, val, comment)
Definition: commandlineflags.h:33
unicharset_training_utils.h
INT_PARAM_FLAG
#define INT_PARAM_FLAG(name, val, comment)
Definition: commandlineflags.h:25
commontraining.h
DOUBLE_PARAM_FLAG
#define DOUBLE_PARAM_FLAG(name, val, comment)
Definition: commandlineflags.h:29
kNumPagesPerBatch
const int kNumPagesPerBatch
Definition: lstmtraining.cpp:68
STRING_PARAM_FLAG
#define STRING_PARAM_FLAG(name, val, comment)
Definition: commandlineflags.h:37
lstmtrainer.h
tesseract::CS_ROUND_ROBIN
Definition: imagedata.h:53
params.h
STRING
Definition: strngs.h:45
lstmtester.h
tesseract::CS_SEQUENTIAL
Definition: imagedata.h:48
fileio.h
main
int main(int argc, char **argv)
Definition: lstmtraining.cpp:74
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
tesseract::LoadFileLinesToStrings
bool LoadFileLinesToStrings(const char *filename, GenericVector< STRING > *lines)
Definition: fileio.h:43
tesseract::LSTMTester::RunEvalAsync
STRING RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
Definition: lstmtester.cpp:53
tprintf.h
GenericVector< STRING >
tesseract::LSTMTrainer::InitCharSet
void InitCharSet(const std::string &traineddata_path)
Definition: lstmtrainer.h:95
tesseract::LSTMTrainer
Definition: lstmtrainer.h:79
ParseArguments
void ParseArguments(int *argc, char ***argv)
Definition: commontraining.cpp:122
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::LSTMTester::LoadAllEvalData
bool LoadAllEvalData(const STRING &filenames_file)
Definition: lstmtester.cpp:31
tesseract::TestCallback
std::function< STRING(int, const double *, const TessdataManager &, int)> TestCallback
Definition: lstmtrainer.h:73