#include <lstm_test.h>
|
void SetUp ()
LSTMTrainerTest ()
std::string TestDataNameToPath (const std::string &name)
std::string TessDataNameToPath (const std::string &name)
std::string TestingNameToPath (const std::string &name)
void SetupTrainerEng (const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
void SetupTrainer (const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, double learning_rate, bool layer_specific, const std::string &kLang)
double TrainIterations (int max_iterations)
double TestIterations (int max_iterations)
double TestIntMode (int test_iterations)
void TestEncodeDecode (const std::string &lang, const std::string &str, bool recode)
void TestEncodeDecodeBoth (const std::string &lang, const std::string &str)
|
Definition at line 46 of file lstm_test.h.
◆ LSTMTrainerTest()
tesseract::LSTMTrainerTest::LSTMTrainerTest ( )
inline, protected
◆ SetUp()
void tesseract::LSTMTrainerTest::SetUp ( )
inline, protected
Definition at line 48 of file lstm_test.h.
49   std::locale::global(std::locale(""));
◆ SetupTrainer()
void tesseract::LSTMTrainerTest::SetupTrainer ( const std::string & network_spec,
                                                const std::string & model_name,
                                                const std::string & unicharset_file,
                                                const std::string & lstmf_file,
                                                bool recode,
                                                bool adam,
                                                double learning_rate,
                                                bool layer_specific,
                                                const std::string & kLang )
inline, protected
Definition at line 71 of file lstm_test.h.
78   ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
83   kLang, !recode, words, words, words, false,
86   std::string checkpoint_path = model_path + "_checkpoint";
87   trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(),
90   absl::StrCat(kLang, ".traineddata")));
91   int net_mode = adam ? NF_ADAM : 0;
94   if (adam) learning_rate *= 20.0;
96   EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
97                                     learning_rate, 0.9, 0.999));
101  LOG(INFO) << "Setup network:" << model_name << "\n";
◆ SetupTrainerEng()
void tesseract::LSTMTrainerTest::SetupTrainerEng ( const std::string & network_spec,
                                                   const std::string & model_name,
                                                   bool recode,
                                                   bool adam )
inline, protected
Definition at line 66 of file lstm_test.h.
68   SetupTrainer(network_spec, model_name, "eng/eng.unicharset",
69                "eng.Arial.exp0.lstmf", recode, adam, 5e-4, false, "eng");
◆ TessDataNameToPath()
◆ TestDataNameToPath()
◆ TestEncodeDecode()
void tesseract::LSTMTrainerTest::TestEncodeDecode ( const std::string & lang,
                                                    const std::string & str,
                                                    bool recode )
inline, protected
Definition at line 169 of file lstm_test.h.
170  std::string unicharset_name = lang + "/" + lang + ".unicharset";
171  std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
172  SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
173               lstmf_name, recode, true, 5e-4, true, lang);
175  EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
178  EXPECT_EQ(str, decoded_str);
◆ TestEncodeDecodeBoth()
◆ TestingNameToPath()
◆ TestIntMode()
double tesseract::LSTMTrainerTest::TestIntMode ( int test_iterations )
inline, protected
Definition at line 153 of file lstm_test.h.
163  EXPECT_LT(int_err, float_err + 1.0);
164  return int_err - float_err;
◆ TestIterations()
double tesseract::LSTMTrainerTest::TestIterations ( int max_iterations )
inline, protected
Definition at line 130 of file lstm_test.h.
132  int iteration = trainer_->sample_iteration();
133  double mean_error = 0.0;
135  while (error_count < max_iterations) {
136    const ImageData& trainingdata =
137        *trainer_->mutable_training_data()->GetPageBySerial(iteration);
138    NetworkIO fwd_outputs, targets;
139    if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) !=
144    trainer_->SetIteration(++iteration);
146  mean_error *= 100.0 / max_iterations;
147  LOG(INFO) << "Tester error rate = " << mean_error << "\n";
◆ TrainIterations()
double tesseract::LSTMTrainerTest::TrainIterations ( int max_iterations )
inline, protected
Definition at line 104 of file lstm_test.h.
105  int iteration = trainer_->training_iteration();
106  int iteration_limit = iteration + max_iterations;
107  double best_error = 100.0;
112    double mean_error = 0.0;
113    while (iteration < target_iteration && iteration < iteration_limit) {
115      iteration = trainer_->training_iteration();
118    trainer_->MaintainCheckpoints(nullptr, &log_str);
119    iteration = trainer_->training_iteration();
122    LOG(INFO) << "Best error = " << best_error << "\n";
123    LOG(INFO) << "Mean error = " << mean_error << "\n";
124    if (mean_error < best_error) best_error = mean_error;
125  } while (iteration < iteration_limit);
126  LOG(INFO) << "Trainer error rate = " << best_error << "\n";
◆ trainer_
std::unique_ptr<LSTMTrainer> tesseract::LSTMTrainerTest::trainer_
protected
The documentation for this class was generated from the following file: