75 tesseract::CheckSharedLibraryVersion();
77 if (FLAGS_model_output.empty()) {
78 tprintf(
"Must provide a --model_output!\n");
81 if (FLAGS_traineddata.empty()) {
82 tprintf(
"Must provide a --traineddata see training wiki\n");
88 test_file +=
"_wtest";
89 FILE* f = fopen(test_file.
c_str(),
"wb");
92 if (remove(test_file.
c_str()) != 0) {
93 tprintf(
"Error, failed to remove %s: %s\n",
94 test_file.
c_str(), strerror(errno));
98 tprintf(
"Error, model output cannot be written: %s\n", strerror(errno));
103 STRING checkpoint_file = FLAGS_model_output.
c_str();
104 checkpoint_file +=
"_checkpoint";
105 STRING checkpoint_bak = checkpoint_file +
".bak";
107 FLAGS_model_output.c_str(),
108 checkpoint_file.
c_str(), FLAGS_debug_interval,
109 static_cast<int64_t>(FLAGS_max_image_MB) * 1048576);
114 if (FLAGS_stop_training || FLAGS_debug_network) {
115 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
nullptr)) {
116 tprintf(
"Failed to read continue from: %s\n",
117 FLAGS_continue_from.c_str());
120 if (FLAGS_debug_network) {
121 trainer.DebugNetwork();
123 if (FLAGS_convert_to_int) trainer.ConvertToInt();
124 if (!trainer.SaveTraineddata(FLAGS_model_output.c_str())) {
125 tprintf(
"Failed to write recognition model : %s\n",
126 FLAGS_model_output.c_str());
133 if (FLAGS_train_listfile.empty()) {
134 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
140 tprintf(
"Failed to load list of training filenames from %s\n",
141 FLAGS_train_listfile.c_str());
146 if (trainer.TryLoadingCheckpoint(checkpoint_file.
c_str(),
nullptr) ||
147 trainer.TryLoadingCheckpoint(checkpoint_bak.
c_str(),
nullptr)) {
148 tprintf(
"Successfully restored trainer from %s\n",
149 checkpoint_file.
c_str());
151 if (!FLAGS_continue_from.empty()) {
153 if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str(),
154 FLAGS_append_index >= 0
155 ? FLAGS_continue_from.c_str()
156 : FLAGS_old_traineddata.c_str())) {
157 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
160 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
161 trainer.InitIterations();
163 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
164 if (FLAGS_append_index >= 0) {
165 tprintf(
"Appending a new network to an old one!!");
166 if (FLAGS_continue_from.empty()) {
167 tprintf(
"Must set --continue_from for appending!\n");
172 if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
173 FLAGS_net_mode, FLAGS_weight_range,
174 FLAGS_learning_rate, FLAGS_momentum,
176 tprintf(
"Failed to create network from spec: %s\n",
177 FLAGS_net_spec.c_str());
180 trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
183 if (!trainer.LoadAllTrainingData(filenames,
184 FLAGS_sequential_training
187 FLAGS_randomly_rotate)) {
188 tprintf(
"Load of images failed!!\n");
195 if (!FLAGS_eval_listfile.empty()) {
196 using namespace std::placeholders;
197 if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
198 tprintf(
"Failed to load eval data from: %s\n",
199 FLAGS_eval_listfile.c_str());
206 int iteration = trainer.training_iteration();
208 iteration < target_iteration &&
209 (iteration < FLAGS_max_iterations || FLAGS_max_iterations == 0);
210 iteration = trainer.training_iteration()) {
211 trainer.TrainOnLine(&trainer,
false);
214 trainer.MaintainCheckpoints(tester_callback, &log_str);
216 }
while (trainer.best_error_rate() > FLAGS_target_error_rate &&
217 (trainer.training_iteration() < FLAGS_max_iterations ||
218 FLAGS_max_iterations == 0));
219 tprintf(
"Finished! Error rate = %g\n", trainer.best_error_rate());