19 #ifdef GOOGLE_TESSERACT 20 #include "base/commandlineflags.h" 31 INT_PARAM_FLAG(debug_interval, 0,
"How often to display the alignment.");
35 "How many imperfect samples between perfect ones.");
41 INT_PARAM_FLAG(max_image_MB, 6000,
"Max memory to use for images.");
45 "File listing training files in lstmf training format.");
47 "File listing eval files in lstmf training format.");
49 "Just convert the training model to a runtime model.");
51 "Convert the recognition model to an integer model.");
53 "Use the training files sequentially instead of round-robin.");
54 INT_PARAM_FLAG(append_index, -1,
"Index in continue_from Network at which to" 55 " attach the new network defined by net_spec");
57 "Get info on distribution of weight values");
58 INT_PARAM_FLAG(max_iterations, 0,
"If set, exit after this many iterations");
60 "Combined Dawgs/Unicharset/Recoder for language model");
62 "When changing the character set, this specifies the old" 63 " character set that is to be replaced");
65 "Train OSD and randomly turn training samples upside-down");
74 int main(
int argc,
char **argv) {
75 tesseract::CheckSharedLibraryVersion();
77 if (FLAGS_model_output.empty()) {
78 tprintf(
"Must provide a --model_output!\n");
81 if (FLAGS_traineddata.empty()) {
82 tprintf(
"Must provide a --traineddata see training wiki\n");
88 test_file +=
"_wtest";
89 FILE* f = fopen(test_file.
c_str(),
"wb");
92 if (
remove(test_file.
c_str()) != 0) {
93 tprintf(
"Error, failed to remove %s: %s\n",
94 test_file.
c_str(), strerror(errno));
98 tprintf(
"Error, model output cannot be written: %s\n", strerror(errno));
103 STRING checkpoint_file = FLAGS_model_output.
c_str();
104 checkpoint_file +=
"_checkpoint";
105 STRING checkpoint_bak = checkpoint_file +
".bak";
107 nullptr,
nullptr,
nullptr,
nullptr, FLAGS_model_output.c_str(),
108 checkpoint_file.
c_str(), FLAGS_debug_interval,
109 static_cast<int64_t
>(FLAGS_max_image_MB) * 1048576);
114 if (FLAGS_stop_training || FLAGS_debug_network) {
116 tprintf(
"Failed to read continue from: %s\n",
117 FLAGS_continue_from.c_str());
120 if (FLAGS_debug_network) {
125 tprintf(
"Failed to write recognition model : %s\n",
126 FLAGS_model_output.c_str());
133 if (FLAGS_train_listfile.empty()) {
134 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
140 tprintf(
"Failed to load list of training filenames from %s\n",
141 FLAGS_train_listfile.c_str());
148 tprintf(
"Successfully restored trainer from %s\n",
149 checkpoint_file.
string());
151 if (!FLAGS_continue_from.empty()) {
154 FLAGS_append_index >= 0
155 ? FLAGS_continue_from.c_str()
156 : FLAGS_old_traineddata.c_str())) {
157 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
160 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
163 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
164 if (FLAGS_append_index >= 0) {
165 tprintf(
"Appending a new network to an old one!!");
166 if (FLAGS_continue_from.empty()) {
167 tprintf(
"Must set --continue_from for appending!\n");
172 if (!trainer.
InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
173 FLAGS_net_mode, FLAGS_weight_range,
174 FLAGS_learning_rate, FLAGS_momentum,
176 tprintf(
"Failed to create network from spec: %s\n",
177 FLAGS_net_spec.c_str());
184 FLAGS_sequential_training
187 FLAGS_randomly_rotate)) {
188 tprintf(
"Load of images failed!!\n");
195 if (!FLAGS_eval_listfile.empty()) {
197 tprintf(
"Failed to load eval data from: %s\n",
198 FLAGS_eval_listfile.c_str());
208 iteration < target_iteration &&
209 (iteration < FLAGS_max_iterations || FLAGS_max_iterations == 0);
218 FLAGS_max_iterations == 0));
219 delete tester_callback;
void InitCharSet(const std::string &traineddata_path)
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
bool SaveTraineddata(const STRING &filename)
const char * string() const
void ParseArguments(int *argc, char ***argv)
int training_iteration() const
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
bool LoadAllEvalData(const STRING &filenames_file)
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
const char * c_str() const
const int kNumPagesPerBatch
void set_perfect_delay(int delay)
DLLSYM void tprintf(const char *format,...)
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.")
STRING RunEvalAsync(int iteration, const double *training_errors, const TessdataManager &model_mgr, int training_stage)
double best_error_rate() const
bool LoadFileLinesToStrings(const STRING &filename, GenericVector< STRING > *lines)
BOOL_PARAM_FLAG(stop_training, false, "Just convert the training model to a runtime model.")
STRING_PARAM_FLAG(net_spec, "", "Network specification")
INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.")
int main(int argc, char **argv)
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)