tesseract  5.0.0-alpha-619-ge9db
tesseract::LSTMTrainerTest Class Reference

#include <lstm_test.h>

Inheritance diagram for tesseract::LSTMTrainerTest:

Protected Member Functions

void SetUp ()
 
 LSTMTrainerTest ()
 
std::string TestDataNameToPath (const std::string &name)
 
std::string TessDataNameToPath (const std::string &name)
 
std::string TestingNameToPath (const std::string &name)
 
void SetupTrainerEng (const std::string &network_spec, const std::string &model_name, bool recode, bool adam)
 
void SetupTrainer (const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, double learning_rate, bool layer_specific, const std::string &kLang)
 
double TrainIterations (int max_iterations)
 
double TestIterations (int max_iterations)
 
double TestIntMode (int test_iterations)
 
void TestEncodeDecode (const std::string &lang, const std::string &str, bool recode)
 
void TestEncodeDecodeBoth (const std::string &lang, const std::string &str)
 

Protected Attributes

std::unique_ptr< LSTMTrainer > trainer_
 

Detailed Description

Definition at line 46 of file lstm_test.h.

Constructor & Destructor Documentation

◆ LSTMTrainerTest()

tesseract::LSTMTrainerTest::LSTMTrainerTest ( )
inline protected

Definition at line 52 of file lstm_test.h.

52 {}

Member Function Documentation

◆ SetUp()

void tesseract::LSTMTrainerTest::SetUp ( )
inline protected

Definition at line 48 of file lstm_test.h.

48  {
49  std::locale::global(std::locale(""));
50  }

◆ SetupTrainer()

void tesseract::LSTMTrainerTest::SetupTrainer ( const std::string &network_spec,
const std::string &model_name,
const std::string &unicharset_file,
const std::string &lstmf_file,
bool  recode,
bool  adam,
double  learning_rate,
bool  layer_specific,
const std::string &kLang 
)
inline protected

Definition at line 71 of file lstm_test.h.

74  {
75 // constexpr char kLang[] = "eng"; // Exact value doesn't matter.
76  std::string unicharset_name = TestDataNameToPath(unicharset_file);
77  UNICHARSET unicharset;
78  ASSERT_TRUE(unicharset.load_from_file(unicharset_name.c_str(), false));
79  std::string script_dir = file::JoinPath(
80  LANGDATA_DIR, "");
81  GenericVector<STRING> words;
82  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, "", FLAGS_test_tmpdir,
83  kLang, !recode, words, words, words, false,
84  nullptr, nullptr));
85  std::string model_path = file::JoinPath(FLAGS_test_tmpdir, model_name);
86  std::string checkpoint_path = model_path + "_checkpoint";
87  trainer_.reset(new LSTMTrainer(model_path.c_str(), checkpoint_path.c_str(),
88  0, 0));
89  trainer_->InitCharSet(file::JoinPath(FLAGS_test_tmpdir, kLang,
90  absl::StrCat(kLang, ".traineddata")));
91  int net_mode = adam ? NF_ADAM : 0;
92  // Adam needs a higher learning rate, due to not multiplying the effective
93  // rate by 1/(1-momentum).
94  if (adam) learning_rate *= 20.0;
95  if (layer_specific) net_mode |= NF_LAYER_SPECIFIC_LR;
96  EXPECT_TRUE(trainer_->InitNetwork(network_spec.c_str(), -1, net_mode, 0.1,
97  learning_rate, 0.9, 0.999));
98  GenericVector<STRING> filenames;
99  filenames.push_back(STRING(TestDataNameToPath(lstmf_file).c_str()));
100  EXPECT_TRUE(trainer_->LoadAllTrainingData(filenames, CS_SEQUENTIAL, false));
101  LOG(INFO) << "Setup network:" << model_name << "\n" ;
102  }

◆ SetupTrainerEng()

void tesseract::LSTMTrainerTest::SetupTrainerEng ( const std::string &network_spec,
const std::string &model_name,
bool  recode,
bool  adam 
)
inline protected

Definition at line 66 of file lstm_test.h.

67  {
68  SetupTrainer(network_spec, model_name, "eng/eng.unicharset",
69  "eng.Arial.exp0.lstmf", recode, adam, 5e-4, false, "eng");
70  }

◆ TessDataNameToPath()

std::string tesseract::LSTMTrainerTest::TessDataNameToPath ( const std::string &name)
inline protected

Definition at line 57 of file lstm_test.h.

57  {
58  return file::JoinPath(TESSDATA_DIR,
59  "" + name);
60  }

◆ TestDataNameToPath()

std::string tesseract::LSTMTrainerTest::TestDataNameToPath ( const std::string &name)
inline protected

Definition at line 53 of file lstm_test.h.

53  {
54  return file::JoinPath(TESTDATA_DIR,
55  "" + name);
56  }

◆ TestEncodeDecode()

void tesseract::LSTMTrainerTest::TestEncodeDecode ( const std::string &lang,
const std::string &str,
bool  recode 
)
inline protected

Definition at line 169 of file lstm_test.h.

169  {
170  std::string unicharset_name = lang + "/" + lang + ".unicharset";
171  std::string lstmf_name = lang + ".Arial_Unicode_MS.exp0.lstmf";
172  SetupTrainer("[1,1,0,32 Lbx100 O1c1]", "bidi-lstm", unicharset_name,
173  lstmf_name, recode, true, 5e-4, true, lang);
174  GenericVector<int> labels;
175  EXPECT_TRUE(trainer_->EncodeString(str.c_str(), &labels));
176  STRING decoded = trainer_->DecodeLabels(labels);
177  std::string decoded_str(&decoded[0], decoded.length());
178  EXPECT_EQ(str, decoded_str);
179  }

◆ TestEncodeDecodeBoth()

void tesseract::LSTMTrainerTest::TestEncodeDecodeBoth ( const std::string &lang,
const std::string &str 
)
inline protected

Definition at line 181 of file lstm_test.h.

181  {
182  TestEncodeDecode(lang, str, false);
183  TestEncodeDecode(lang, str, true);
184  }

◆ TestingNameToPath()

std::string tesseract::LSTMTrainerTest::TestingNameToPath ( const std::string &name)
inline protected

Definition at line 61 of file lstm_test.h.

61  {
62  return file::JoinPath(TESTING_DIR,
63  "" + name);
64  }

◆ TestIntMode()

double tesseract::LSTMTrainerTest::TestIntMode ( int  test_iterations)
inline protected

Definition at line 153 of file lstm_test.h.

153  {
154  GenericVector<char> trainer_data;
155  EXPECT_TRUE(trainer_->SaveTrainingDump(NO_BEST_TRAINER, trainer_.get(),
156  &trainer_data));
157  // Get the error on the next few iterations in float mode.
158  double float_err = TestIterations(test_iterations);
159  // Restore the dump, convert to int and test error on that.
160  EXPECT_TRUE(trainer_->ReadTrainingDump(trainer_data, trainer_.get()));
161  trainer_->ConvertToInt();
162  double int_err = TestIterations(test_iterations);
163  EXPECT_LT(int_err, float_err + 1.0);
164  return int_err - float_err;
165  }

◆ TestIterations()

double tesseract::LSTMTrainerTest::TestIterations ( int  max_iterations)
inline protected

Definition at line 130 of file lstm_test.h.

130  {
131  CHECK_GT(max_iterations, 0);
132  int iteration = trainer_->sample_iteration();
133  double mean_error = 0.0;
134  int error_count = 0;
135  while (error_count < max_iterations) {
136  const ImageData& trainingdata =
137  *trainer_->mutable_training_data()->GetPageBySerial(iteration);
138  NetworkIO fwd_outputs, targets;
139  if (trainer_->PrepareForBackward(&trainingdata, &fwd_outputs, &targets) !=
140  UNENCODABLE) {
141  mean_error += trainer_->NewSingleError(ET_CHAR_ERROR);
142  ++error_count;
143  }
144  trainer_->SetIteration(++iteration);
145  }
146  mean_error *= 100.0 / max_iterations;
147  LOG(INFO) << "Tester error rate = " << mean_error << "\n" ;
148  return mean_error;
149  }

◆ TrainIterations()

double tesseract::LSTMTrainerTest::TrainIterations ( int  max_iterations)
inline protected

Definition at line 104 of file lstm_test.h.

104  {
105  int iteration = trainer_->training_iteration();
106  int iteration_limit = iteration + max_iterations;
107  double best_error = 100.0;
108  do {
109  STRING log_str;
110  int target_iteration = iteration + kBatchIterations;
111  // Train a few.
112  double mean_error = 0.0;
113  while (iteration < target_iteration && iteration < iteration_limit) {
114  trainer_->TrainOnLine(trainer_.get(), false);
115  iteration = trainer_->training_iteration();
116  mean_error += trainer_->LastSingleError(ET_CHAR_ERROR);
117  }
118  trainer_->MaintainCheckpoints(nullptr, &log_str);
119  iteration = trainer_->training_iteration();
120  mean_error *= 100.0 / kBatchIterations;
121  LOG(INFO) << log_str.c_str();
122  LOG(INFO) << "Best error = " << best_error << "\n" ;
123  LOG(INFO) << "Mean error = " << mean_error << "\n" ;
124  if (mean_error < best_error) best_error = mean_error;
125  } while (iteration < iteration_limit);
126  LOG(INFO) << "Trainer error rate = " << best_error << "\n";
127  return best_error;
128  }

Member Data Documentation

◆ trainer_

std::unique_ptr<LSTMTrainer> tesseract::LSTMTrainerTest::trainer_
protected

Definition at line 186 of file lstm_test.h.


The documentation for this class was generated from the following file:
lstm_test.h
UNICHARSET::load_from_file
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:378
tesseract::LSTMTrainerTest::TestEncodeDecode
void TestEncodeDecode(const std::string &lang, const std::string &str, bool recode)
Definition: lstm_test.h:169
file::JoinPath
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:43
string
std::string string
Definition: equationdetect_test.cc:21
INFO
Definition: log.h:29
tesseract::CombineLangModel
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const GenericVector< STRING > &words, const GenericVector< STRING > &puncs, const GenericVector< STRING > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
Definition: lang_model_helpers.cpp:185
tesseract::LSTMTrainerTest::TestDataNameToPath
std::string TestDataNameToPath(const std::string &name)
Definition: lstm_test.h:53
tesseract::LSTMTrainerTest::SetupTrainer
void SetupTrainer(const std::string &network_spec, const std::string &model_name, const std::string &unicharset_file, const std::string &lstmf_file, bool recode, bool adam, double learning_rate, bool layer_specific, const std::string &kLang)
Definition: lstm_test.h:71
STRING
Definition: strngs.h:45
tesseract::CS_SEQUENTIAL
Definition: imagedata.h:48
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::NF_ADAM
Definition: network.h:88
FLAGS_test_tmpdir
const char * FLAGS_test_tmpdir
Definition: include_gunit.h:20
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
UNICHARSET
Definition: unicharset.h:145
tesseract::LSTMTrainerTest::trainer_
std::unique_ptr< LSTMTrainer > trainer_
Definition: lstm_test.h:186
tesseract::LSTMTrainerTest::TestIterations
double TestIterations(int max_iterations)
Definition: lstm_test.h:130
tesseract::ET_CHAR_ERROR
Definition: lstmtrainer.h:41
GenericVector< STRING >
STRING::length
int32_t length() const
Definition: strngs.cpp:187
tesseract::kBatchIterations
const int kBatchIterations
Definition: lstm_test.h:37
tesseract::NF_LAYER_SPECIFIC_LR
Definition: network.h:87
CHECK_GT
#define CHECK_GT(test, value)
Definition: include_gunit.h:59
LOG
Definition: cleanapi_test.cc:19
tesseract::UNENCODABLE
Definition: lstmtrainer.h:50
tesseract::NO_BEST_TRAINER
Definition: lstmtrainer.h:58