tesseract  5.0.0-alpha-619-ge9db
lstm_test.h
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
12 #ifndef TESSERACT_UNITTEST_LSTM_TEST_H_
13 #define TESSERACT_UNITTEST_LSTM_TEST_H_
14 
15 #include <memory>
16 #include <string>
17 #include <utility>
18 
19 #include "include_gunit.h"
20 
21 #include "absl/strings/str_cat.h"
22 #include "tprintf.h"
23 #include <tesseract/helpers.h>
24 
25 #include "functions.h"
26 #include "lang_model_helpers.h"
27 #include "log.h" // for LOG
28 #include "lstmtrainer.h"
29 #include "unicharset.h"
30 
31 namespace tesseract {
32 
#if DEBUG_DETAIL == 0
// Normal builds: every trainer runs for kTrainerIterations iterations, and
// accuracy is checked once every kBatchIterations iterations.
constexpr int kTrainerIterations = 600;
constexpr int kBatchIterations = 100;
#else
// DEBUG_DETAIL builds: only a couple of iterations, with an accuracy check
// after every single one.
constexpr int kTrainerIterations = 2;
constexpr int kBatchIterations = 1;
#endif
44 
45 // The fixture for testing LSTMTrainer.
46 class LSTMTrainerTest : public testing::Test {
47  protected:
48  void SetUp() {
49  std::locale::global(std::locale(""));
50  }
51 
54  return file::JoinPath(TESTDATA_DIR,
55  "" + name);
56  }
58  return file::JoinPath(TESSDATA_DIR,
59  "" + name);
60  }
62  return file::JoinPath(TESTING_DIR,
63  "" + name);
64  }
65 
66  void SetupTrainerEng(const std::string& network_spec, const std::string& model_name,
67  bool recode, bool adam) {
68  SetupTrainer(network_spec, model_name, "eng/eng.unicharset",
69  "eng.Arial.exp0.lstmf", recode, adam, 5e-4, false, "eng");
70  }
71  void SetupTrainer(const std::string& network_spec, const std::string& model_name,
72  const std::string& unicharset_file, const std::string& lstmf_file,
73  bool recode, bool adam, double learning_rate,
74  bool layer_specific, const std::string& kLang) {
75 // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
76  std::string unicharset_name = TestDataNameToPath(unicharset_file);
77  UNICHARSET unicharset;
78  ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
79  std::string script_dir = file::JoinPath(
80  LANGDATA_DIR, "");
82  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir,
83  kLang, !recode, words, words, words, false,
84  nullptr, nullptr));
85  std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
86  std::string checkpoint_path = model_path + "_checkpoint";
87  trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(),
88  0, 0));
89  trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang,
90  absl::StrCat(kLang, ".traineddata")));
91  int net_mode = adam ? NF_ADAM : 0;
92  // Adam needs a higher learning rate, due to not multiplying the effective
93  // rate by 1/(1-momentum).
94  if (adam) learning_rate *= 20.0;
95  if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR;
96  EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
97  learning_rate, 0.9, 0.999));
98  GenericVector<STRING> filenames;
99  filenames.push_back(STRING(TestDataNameToPath(lstmf_file).c_str()));
100  EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
101  LOG(INFO) << "Setup network:" << model_name << "\n" ;
102  }
103  // Trains for a given number of iterations and returns the char error rate.
104  double TrainIterations(int max_iterations) {
105  int iteration = trainer_->training_iteration();
106  int iteration_limit = iteration + max_iterations;
107  double best_error = 100.0;
108  do {
109  STRING log_str;
110  int target_iteration = iteration + kBatchIterations;
111  // Train a few.
112  double mean_error = 0.0;
113  while (iteration < target_iteration && iteration < iteration_limit) {
114  trainer_->TrainOnLine(trainer_.get(), false);
115  iteration = trainer_->training_iteration();
116  mean_error += trainer_->LastSingleError(ET_CHAR_ERROR);
117  }
118  trainer_->MaintainCheckpoints(nullptr, &log_str);
119  iteration = trainer_->training_iteration();
120  mean_error *= 100.0 / kBatchIterations;
121  LOG(INFO) << log_str.c_str();
122  LOG(INFO) << "Best error = " << best_error << "\n" ;
123  LOG(INFO) << "Mean error = " << mean_error << "\n" ;
124  if (mean_error < best_error) best_error = mean_error;
125  } while (iteration < iteration_limit);
126  LOG(INFO) << "Trainer error rate = " << best_error << "\n";
127  return best_error;
128  }
129  // Tests for a given number of iterations and returns the char error rate.
130  double TestIterations(int max_iterations) {
131  CHECK_GT(max_iterations, 0);
132  int iteration = trainer_->sample_iteration();
133  double mean_error = 0.0;
134  int error_count = 0;
135  while (error_count < max_iterations) {
136  const ImageData& trainingdata =
137  *trainer_->mutable_training_data()->GetPageBySerial(iteration);
138  NetworkIO fwd_outputs, targets;
139  if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) !=
140  UNENCODABLE) {
141  mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
142  ++error_count;
143  }
144  trainer_->SetIteration(++iteration);
145  }
146  mean_error *= 100.0 / max_iterations;
147  LOG(INFO) << "Tester error rate = " << mean_error << "\n" ;
148  return mean_error;
149  }
150  // Tests that the current trainer_ can be converted to int mode and still gets
151  // within 1% of the error rate. Returns the increase in error from float to
152  // int.
153  double TestIntMode(int test_iterations) {
154  GenericVector<char> trainer_data;
155  EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, trainer_.get(),
156  &trainer_data));
157  // Get the error on the next few iterations in float mode.
158  double float_err = TestIterations(test_iterations);
159  // Restore the dump, convert to int and test error on that.
160  EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, trainer_.get()));
161  trainer_->ConvertToInt();
162  double int_err = TestIterations(test_iterations);
163  EXPECT_LT(int_err, float_err + 1.0);
164  return int_err - float_err;
165  }
166  // Sets up a trainer with the given language and given recode+ctc condition.
167  // It then verifies that the given str encodes and decodes back to the same
168  // string.
169  void TestEncodeDecode(const std::string& lang, const std::string& str, bool recode) {
170  std::string unicharset_name = lang + "/" + lang + ".unicharset";
171  std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
172  SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
173  lstmf_name, recode, true, 5e-4, true, lang);
174  GenericVector<int> labels;
175  EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
176  STRING decoded = trainer_->DecodeLabels(labels);
177  std::string decoded_str(&decoded[0], decoded.length());
178  EXPECT_EQ(str, decoded_str);
179  }
180  // Calls TestEncodeDeode with both recode on and off.
181  void TestEncodeDecodeBoth(const std::string& lang, const std::string& str) {
182  TestEncodeDecode(lang, str, false);
183  TestEncodeDecode(lang, str, true);
184  }
185 
186  std::unique_ptr<LSTMTrainer> trainer_;
187 };
188 
189 } // namespace tesseract.
190 
191 #endif // TESSERACT_UNITTEST_LSTM_TEST_H_
UNICHARSET::load_from_file
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:378
tesseract::LSTMTrainerTest::TestEncodeDecode
void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode)
Definition: lstm_test.h:169
file::JoinPath
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:43
string
std::string string
Definition: equationdetect_test.cc:21
INFO
Definition: log.h:29
tesseract::CombineLangModel
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const GenericVector< STRING > &words, const GenericVector< STRING > &puncs, const GenericVector< STRING > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
Definition: lang_model_helpers.cpp:185
tesseract::LSTMTrainerTest::TestDataNameToPath
std::string TestDataNameToPath(const std::string &name)
Definition: lstm_test.h:53
tesseract::kTrainerIterations
const int kTrainerIterations
Definition: lstm_test.h:35
tesseract::LSTMTrainerTest::TestEncodeDecodeBoth
void TestEncodeDecodeBoth(const std::string &lang, const std::string &str)
Definition: lstm_test.h:181
tesseract::LSTMTrainerTest::TestingNameToPath
std::string TestingNameToPath(const std::string &name)
Definition: lstm_test.h:61
lstmtrainer.h
tesseract::LSTMTrainerTest::SetupTrainer
void SetupTrainer(const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, double learning_rate, bool layer_specific, const std::string &kLang)
Definition: lstm_test.h:71
tesseract::LSTMTrainerTest::TrainIterations
double TrainIterations(int max_iterations)
Definition: lstm_test.h:104
STRING
Definition: strngs.h:45
tesseract::LSTMTrainerTest::TessDataNameToPath
std::string TessDataNameToPath(const std::string &name)
Definition: lstm_test.h:57
include_gunit.h
tesseract::CS_SEQUENTIAL
Definition: imagedata.h:48
tesseract::ImageData
Definition: imagedata.h:104
lang_model_helpers.h
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::NF_ADAM
Definition: network.h:88
FLAGS_test_tmpdir
const char * FLAGS_test_tmpdir
Definition: include_gunit.h:20
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
unicharset.h
UNICHARSET
Definition: unicharset.h:145
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::LSTMTrainerTest::trainer_
std::unique_ptr< LSTMTrainer > trainer_
Definition: lstm_test.h:186
helpers.h
tesseract
Definition: baseapi.h:65
tesseract::LSTMTrainerTest::TestIterations
double TestIterations(int max_iterations)
Definition: lstm_test.h:130
tesseract::ET_CHAR_ERROR
Definition: lstmtrainer.h:41
tprintf.h
GenericVector< STRING >
STRING::length
int32_t length() const
Definition: strngs.cpp:187
tesseract::LSTMTrainer
Definition: lstmtrainer.h:79
tesseract::LSTMTrainerTest::LSTMTrainerTest
LSTMTrainerTest()
Definition: lstm_test.h:52
tesseract::LSTMTrainerTest
Definition: lstm_test.h:46
tesseract::kBatchIterations
const int kBatchIterations
Definition: lstm_test.h:37
functions.h
tesseract::NF_LAYER_SPECIFIC_LR
Definition: network.h:87
log.h
CHECK_GT
#define CHECK_GT(test, value)
Definition: include_gunit.h:59
LOG
Definition: cleanapi_test.cc:19
tesseract::UNENCODABLE
Definition: lstmtrainer.h:50
tesseract::LSTMTrainerTest::TestIntMode
double TestIntMode(int test_iterations)
Definition: lstm_test.h:153
tesseract::NO_BEST_TRAINER
Definition: lstmtrainer.h:58
tesseract::LSTMTrainerTest::SetupTrainerEng
void SetupTrainerEng(const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
Definition: lstm_test.h:66
tesseract::LSTMTrainerTest::SetUp
void SetUp()
Definition: lstm_test.h:48