12 #ifndef TESSERACT_UNITTEST_LSTM_TEST_H_
13 #define TESSERACT_UNITTEST_LSTM_TEST_H_
21 #include "absl/strings/str_cat.h"
49 std::locale::global(std::locale(
""));
67 bool recode,
bool adam) {
68 SetupTrainer(network_spec, model_name,
"eng/eng.unicharset",
69 "eng.Arial.exp0.lstmf", recode, adam, 5e-4,
false,
"eng");
73 bool recode,
bool adam,
double learning_rate,
78 ASSERT_TRUE(unicharset.
load_from_file(unicharset_name.c_str(),
false));
83 kLang, !recode, words, words, words,
false,
86 std::string checkpoint_path = model_path +
"_checkpoint";
90 absl::StrCat(kLang,
".traineddata")));
91 int net_mode = adam ?
NF_ADAM : 0;
94 if (adam) learning_rate *= 20.0;
96 EXPECT_TRUE(
trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
97 learning_rate, 0.9, 0.999));
101 LOG(
INFO) <<
"Setup network:" << model_name <<
"\n" ;
105 int iteration =
trainer_->training_iteration();
106 int iteration_limit = iteration + max_iterations;
107 double best_error = 100.0;
112 double mean_error = 0.0;
113 while (iteration < target_iteration && iteration < iteration_limit) {
115 iteration =
trainer_->training_iteration();
118 trainer_->MaintainCheckpoints(
nullptr, &log_str);
119 iteration =
trainer_->training_iteration();
122 LOG(
INFO) <<
"Best error = " << best_error <<
"\n" ;
123 LOG(
INFO) <<
"Mean error = " << mean_error <<
"\n" ;
124 if (mean_error < best_error) best_error = mean_error;
125 }
while (iteration < iteration_limit);
126 LOG(
INFO) <<
"Trainer error rate = " << best_error <<
"\n";
132 int iteration =
trainer_->sample_iteration();
133 double mean_error = 0.0;
135 while (error_count < max_iterations) {
137 *
trainer_->mutable_training_data()->GetPageBySerial(iteration);
139 if (
trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) !=
144 trainer_->SetIteration(++iteration);
146 mean_error *= 100.0 / max_iterations;
147 LOG(
INFO) <<
"Tester error rate = " << mean_error <<
"\n" ;
163 EXPECT_LT(int_err, float_err + 1.0);
164 return int_err - float_err;
170 std::string unicharset_name = lang +
"/" + lang +
".unicharset";
171 std::string lstmf_name = lang +
".Arial_Unicode_MS.exp0.lstmf";
172 SetupTrainer(
"[1,1,0,32 Lbx100 O1c1]",
"bidi-lstm", unicharset_name,
173 lstmf_name, recode,
true, 5e-4,
true, lang);
175 EXPECT_TRUE(
trainer_->EncodeString(str.c_str(), &labels));
178 EXPECT_EQ(str, decoded_str);
191 #endif // TESSERACT_UNITTEST_LSTM_TEST_H_