tesseract  5.0.0-alpha-619-ge9db
mastertrainer_test.cc
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
12 // Although this is a trivial-looking test, it exercises a lot of code:
13 // SampleIterator has to correctly iterate over the correct characters, or
14 // it will fail.
15 // The canonical and cloud features computed by TrainingSampleSet need to
16 // be correct, along with the distance caches, organizing samples by font
17 // and class, indexing of features, distance calculations.
18 // IntFeatureDist has to work, or the canonical samples won't work.
19 // MasterTrainer's ability to read tr files and set itself up is tested.
20 // Finally the serialize/deserialize test ensures that MasterTrainer,
21 // TrainingSampleSet, TrainingSample can all serialize/deserialize correctly
22 // enough to reproduce the same results.
23 
24 #include <string>
25 #include <utility>
26 #include <vector>
27 
28 #include "absl/strings/numbers.h" // for safe_strto32
29 #include "absl/strings/str_split.h" // for absl::StrSplit
30 
31 #include "include_gunit.h"
32 
34 #include "log.h" // for LOG
35 #include "unicharset.h"
36 #include "errorcounter.h"
37 #include "mastertrainer.h"
38 #include "shapeclassifier.h"
39 #include "shapetable.h"
40 #include "trainingsample.h"
41 #include "commontraining.h"
42 #include "tessopt.h" // tessoptind
43 
// Error-pattern specification for the MockClassifier.
// The kNum*Errs values are cumulative sample counts: the mock compares its
// running sample counter against each of them in turn to decide which kind
// of answer to fabricate for the current sample.
static constexpr int kNumTopNErrs = 10;
static constexpr int kNumTop2Errs = kNumTopNErrs + 20;
static constexpr int kNumTop1Errs = kNumTop2Errs + 30;
static constexpr int kNumTopTopErrs = kNumTop1Errs + 25;
// Samples beyond this count are rejected (the mock returns no results).
static constexpr int kNumNonReject = 1000;
static constexpr int kNumCorrect = kNumNonReject - kNumTop1Errs;
// Total number of answers produced by the mock: one per non-rejected sample,
// plus the additional answers attached to the multi-answer error categories.
static constexpr int kNumAnswers =
    kNumNonReject                          // one answer per accepted sample
    + 2 * (kNumTop2Errs - kNumTopNErrs)    // top-2 errors: two extra answers
    + (kNumTop1Errs - kNumTop2Errs)        // top-1 errors: one extra answer
    + (kNumTopTopErrs - kNumTop1Errs);     // top-top errors: one extra answer
56 
57 #ifndef DISABLED_LEGACY_ENGINE
// Parses the whole of str as a signed 32-bit integer and stores it in
// *pResult.
//
// The numeric base is auto-detected exactly as strtoll does for base 0
// (a "0x"/"0X" prefix selects hex, a leading "0" selects octal, otherwise
// decimal). Leading whitespace is skipped by strtoll; trailing whitespace is
// tolerated as well.
//
// Returns true on success. Returns false, leaving *pResult set to 0, when
// str contains no number, has trailing non-whitespace characters (including
// an embedded NUL), or holds a value outside the 32-bit signed range.
// pResult must not be null.
static bool safe_strto32(const std::string& str, int* pResult)
{
  // Deterministic output even when parsing fails.
  *pResult = 0;

  const char* const begin = str.c_str();
  char* end = nullptr;
  // Parse into a wider type so that 32-bit overflow can be detected below;
  // on long long overflow strtoll saturates, which is also out of range.
  const long long parsed = strtoll(begin, &end, 0);
  if (end == begin) return false;  // Nothing numeric was consumed.

  // Allow (only) whitespace after the number.
  while (*end == ' ' || *end == '\t' || *end == '\n' || *end == '\r' ||
         *end == '\f' || *end == '\v') {
    ++end;
  }
  // Anything left over (garbage or an embedded NUL) is a parse failure.
  if (end != begin + str.size()) return false;

  const long long kInt32Min = -2147483647LL - 1;
  const long long kInt32Max = 2147483647LL;
  if (parsed < kInt32Min || parsed > kInt32Max) return false;

  *pResult = static_cast<int>(parsed);
  return true;
}
64 #endif
65 
66 namespace tesseract {
67 
68 // Mock ShapeClassifier that cheats by looking at the correct answer, and
69 // creates a specific pattern of errors that can be tested.
71  public:
72  explicit MockClassifier(ShapeTable* shape_table)
73  : shape_table_(shape_table), num_done_(0), done_bad_font_(false) {
74  // Add a false font answer to the shape table. We pick a random unichar_id,
75  // add a new shape for it with a false font. Font must actually exist in
76  // the font table, but not match anything in the first 1000 samples.
77  false_unichar_id_ = 67;
78  false_shape_ = shape_table_->AddShape(false_unichar_id_, 25);
79  }
80  virtual ~MockClassifier() {}
81 
82  // Classifies the given [training] sample, writing to results.
83  // If debug is non-zero, then various degrees of classifier dependent debug
84  // information is provided.
85  // If keep_this (a shape index) is >= 0, then the results should always
86  // contain keep_this, and (if possible) anything of intermediate confidence.
87  // The return value is the number of classes saved in results.
88  virtual int ClassifySample(const TrainingSample& sample, Pix* page_pix,
89  int debug, UNICHAR_ID keep_this,
90  GenericVector<ShapeRating>* results) {
91  results->clear();
92  // Everything except the first kNumNonReject is a reject.
93  if (++num_done_ > kNumNonReject) return 0;
94 
95  int class_id = sample.class_id();
96  int font_id = sample.font_id();
97  int shape_id = shape_table_->FindShape(class_id, font_id);
98  // Get ids of some wrong answers.
99  int wrong_id1 = shape_id > 10 ? shape_id - 1 : shape_id + 1;
100  int wrong_id2 = shape_id > 10 ? shape_id - 2 : shape_id + 2;
101  if (num_done_ <= kNumTopNErrs) {
102  // The first kNumTopNErrs are top-n errors.
103  results->push_back(ShapeRating(wrong_id1, 1.0f));
104  } else if (num_done_ <= kNumTop2Errs) {
105  // The next kNumTop2Errs - kNumTopNErrs are top-2 errors.
106  results->push_back(ShapeRating(wrong_id1, 1.0f));
107  results->push_back(ShapeRating(wrong_id2, 0.875f));
108  results->push_back(ShapeRating(shape_id, 0.75f));
109  } else if (num_done_ <= kNumTop1Errs) {
110  // The next kNumTop1Errs - kNumTop2Errs are top-1 errors.
111  results->push_back(ShapeRating(wrong_id1, 1.0f));
112  results->push_back(ShapeRating(shape_id, 0.8f));
113  } else if (num_done_ <= kNumTopTopErrs) {
114  // The next kNumTopTopErrs - kNumTop1Errs are cases where the actual top
115  // is not correct, but do not count as a top-1 error because the rating
116  // is close enough to the top answer.
117  results->push_back(ShapeRating(wrong_id1, 1.0f));
118  results->push_back(ShapeRating(shape_id, 0.99f));
119  } else if (!done_bad_font_ && class_id == false_unichar_id_) {
120  // There is a single character with a bad font.
121  results->push_back(ShapeRating(false_shape_, 1.0f));
122  done_bad_font_ = true;
123  } else {
124  // Everything else is correct.
125  results->push_back(ShapeRating(shape_id, 1.0f));
126  }
127  return results->size();
128  }
129  // Provides access to the ShapeTable that this classifier works with.
130  virtual const ShapeTable* GetShapeTable() const { return shape_table_; }
131 
132  private:
133  // Borrowed pointer to the ShapeTable.
134  ShapeTable* shape_table_;
135  // Unichar_id of a random character that occurs after the first 60 samples.
136  int false_unichar_id_;
137  // Shape index of prepared false answer for false_unichar_id.
138  int false_shape_;
139  // The number of classifications we have processed.
140  int num_done_;
141  // True after the false font has been emitted.
142  bool done_bad_font_;
143 };
144 
145 } // namespace tesseract
146 
147 namespace {
148 
150 using tesseract::Shape;
153 
154 const double kMin1lDistance = 0.25;
155 
156 // The fixture for testing Tesseract.
157 class MasterTrainerTest : public testing::Test {
158 #ifndef DISABLED_LEGACY_ENGINE
159  protected:
160  void SetUp() {
161  std::locale::global(std::locale(""));
162  }
163 
164  std::string TestDataNameToPath(const std::string& name) {
165  return file::JoinPath(TESTING_DIR, name);
166  }
167  std::string TmpNameToPath(const std::string& name) {
168  return file::JoinPath(FLAGS_test_tmpdir, name);
169  }
170 
171  MasterTrainerTest() {
172  shape_table_ = nullptr;
173  master_trainer_ = nullptr;
174  }
175  ~MasterTrainerTest() {
176  delete master_trainer_;
177  delete shape_table_;
178  }
179 
180  // Initializes the master_trainer_ and shape_table_.
181  // if load_from_tmp, then reloads a master trainer that was saved by a
182  // previous call in which it was false.
183  void LoadMasterTrainer() {
184  FLAGS_output_trainer = TmpNameToPath("tmp_trainer").c_str();
185  FLAGS_F = file::JoinPath(LANGDATA_DIR, "font_properties").c_str();
186  FLAGS_X = TestDataNameToPath("eng.xheights").c_str();
187  FLAGS_U = TestDataNameToPath("eng.unicharset").c_str();
188  std::string tr_file_name(TestDataNameToPath("eng.Arial.exp0.tr"));
189  const char* argv[] = {tr_file_name.c_str()};
190  int argc = 1;
191  STRING file_prefix;
192  delete master_trainer_;
193  delete shape_table_;
194  shape_table_ = nullptr;
195  tessoptind = 0;
196  master_trainer_ =
197  LoadTrainingData(argc, argv, false, &shape_table_, &file_prefix);
198  EXPECT_TRUE(master_trainer_ != nullptr);
199  EXPECT_TRUE(shape_table_ != nullptr);
200  }
201 
202  // EXPECTs that the distance between I and l in Arial is 0 and that the
203  // distance to 1 is significantly not 0.
204  void VerifyIl1() {
205  // Find the font id for Arial.
206  int font_id = master_trainer_->GetFontInfoId("Arial");
207  EXPECT_GE(font_id, 0);
208  // Track down the characters we are interested in.
209  int unichar_I = master_trainer_->unicharset().unichar_to_id("I");
210  EXPECT_GT(unichar_I, 0);
211  int unichar_l = master_trainer_->unicharset().unichar_to_id("l");
212  EXPECT_GT(unichar_l, 0);
213  int unichar_1 = master_trainer_->unicharset().unichar_to_id("1");
214  EXPECT_GT(unichar_1, 0);
215  // Now get the shape ids.
216  int shape_I = shape_table_->FindShape(unichar_I, font_id);
217  EXPECT_GE(shape_I, 0);
218  int shape_l = shape_table_->FindShape(unichar_l, font_id);
219  EXPECT_GE(shape_l, 0);
220  int shape_1 = shape_table_->FindShape(unichar_1, font_id);
221  EXPECT_GE(shape_1, 0);
222 
223  float dist_I_l =
224  master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_l);
225  // No tolerance here. We expect that I and l should match exactly.
226  EXPECT_EQ(0.0f, dist_I_l);
227  float dist_l_I =
228  master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_I);
229  // BOTH ways.
230  EXPECT_EQ(0.0f, dist_l_I);
231 
232  // l/1 on the other hand should be distinct.
233  float dist_l_1 =
234  master_trainer_->ShapeDistance(*shape_table_, shape_l, shape_1);
235  EXPECT_GT(dist_l_1, kMin1lDistance);
236  float dist_1_l =
237  master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_l);
238  EXPECT_GT(dist_1_l, kMin1lDistance);
239 
240  // So should I/1.
241  float dist_I_1 =
242  master_trainer_->ShapeDistance(*shape_table_, shape_I, shape_1);
243  EXPECT_GT(dist_I_1, kMin1lDistance);
244  float dist_1_I =
245  master_trainer_->ShapeDistance(*shape_table_, shape_1, shape_I);
246  EXPECT_GT(dist_1_I, kMin1lDistance);
247  }
248 
249  // Objects declared here can be used by all tests in the test case for Foo.
250  ShapeTable* shape_table_;
251  MasterTrainer* master_trainer_;
252 #endif
253 };
254 
255 // Tests that the MasterTrainer correctly loads its data and reaches the correct
256 // conclusion over the distance between Arial I l and 1.
257 TEST_F(MasterTrainerTest, Il1Test) {
258 #ifdef DISABLED_LEGACY_ENGINE
259  // Skip test because LoadTrainingData is missing.
260  GTEST_SKIP();
261 #else
262  // Initialize the master_trainer_ and load the Arial tr file.
263  LoadMasterTrainer();
264  VerifyIl1();
265 #endif
266 }
267 
268 // Tests the ErrorCounter using a MockClassifier to check that it counts
269 // error categories correctly.
270 TEST_F(MasterTrainerTest, ErrorCounterTest) {
271 #ifdef DISABLED_LEGACY_ENGINE
272  // Skip test because LoadTrainingData is missing.
273  GTEST_SKIP();
274 #else
275  // Initialize the master_trainer_ from the saved tmp file.
276  LoadMasterTrainer();
277  // Add the space character to the shape_table_ if not already present to
278  // count junk.
279  if (shape_table_->FindShape(0, -1) < 0) shape_table_->AddShape(0, 0);
280  // Make a mock classifier.
281  tesseract::ShapeClassifier* shape_classifier =
282  new tesseract::MockClassifier(shape_table_);
283  // Get the accuracy report.
284  STRING accuracy_report;
285  master_trainer_->TestClassifierOnSamples(tesseract::CT_UNICHAR_TOP1_ERR, 0,
286  false, shape_classifier,
287  &accuracy_report);
288  LOG(INFO) << accuracy_report.c_str();
289  std::string result_string = accuracy_report.c_str();
290  std::vector<std::string> results =
291  absl::StrSplit(result_string, '\t', absl::SkipEmpty());
292  EXPECT_EQ(tesseract::CT_SIZE + 1, results.size());
293  int result_values[tesseract::CT_SIZE];
294  for (int i = 0; i < tesseract::CT_SIZE; ++i) {
295  EXPECT_TRUE(safe_strto32(results[i + 1], &result_values[i]));
296  }
297  // These tests are more-or-less immune to additions to the number of
298  // categories or changes in the training data.
299  int num_samples = master_trainer_->GetSamples()->num_raw_samples();
300  EXPECT_EQ(kNumCorrect, result_values[tesseract::CT_UNICHAR_TOP_OK]);
301  EXPECT_EQ(1, result_values[tesseract::CT_FONT_ATTR_ERR]);
302  EXPECT_EQ(kNumTopTopErrs, result_values[tesseract::CT_UNICHAR_TOPTOP_ERR]);
303  EXPECT_EQ(kNumTop1Errs, result_values[tesseract::CT_UNICHAR_TOP1_ERR]);
304  EXPECT_EQ(kNumTop2Errs, result_values[tesseract::CT_UNICHAR_TOP2_ERR]);
305  EXPECT_EQ(kNumTopNErrs, result_values[tesseract::CT_UNICHAR_TOPN_ERR]);
306  // Each of the TOPTOP errs also counts as a multi-unichar.
307  EXPECT_EQ(kNumTopTopErrs - kNumTop1Errs,
308  result_values[tesseract::CT_OK_MULTI_UNICHAR]);
309  EXPECT_EQ(num_samples - kNumNonReject, result_values[tesseract::CT_REJECT]);
310  EXPECT_EQ(kNumAnswers, result_values[tesseract::CT_NUM_RESULTS]);
311 
312  delete shape_classifier;
313 #endif
314 }
315 
316 } // namespace.
file::JoinPath
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:43
string
std::string string
Definition: equationdetect_test.cc:21
INFO
Definition: log.h:29
tesseract::MockClassifier::GetShapeTable
virtual const ShapeTable * GetShapeTable() const
Definition: mastertrainer_test.cc:130
tesseract::MockClassifier::MockClassifier
MockClassifier(ShapeTable *shape_table)
Definition: mastertrainer_test.cc:72
commontraining.h
tesseract::CT_SIZE
Definition: errorcounter.h:89
tesseract::Shape
Definition: shapetable.h:184
tesseract::MockClassifier
Definition: mastertrainer_test.cc:70
errorcounter.h
STRING
Definition: strngs.h:45
tesseract::UnicharAndFonts
Definition: shapetable.h:159
tesseract::CT_REJECT
Definition: errorcounter.h:81
mastertrainer.h
include_gunit.h
tesseract::TEST_F
TEST_F(EquationFinderTest, IdentifySpecialText)
Definition: equationdetect_test.cc:181
tesseract::CT_UNICHAR_TOPN_ERR
Definition: errorcounter.h:76
tesseract::MockClassifier::ClassifySample
virtual int ClassifySample(const TrainingSample &sample, Pix *page_pix, int debug, UNICHAR_ID keep_this, GenericVector< ShapeRating > *results)
Definition: mastertrainer_test.cc:88
genericvector.h
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::ShapeClassifier
Definition: shapeclassifier.h:43
FLAGS_test_tmpdir
const char * FLAGS_test_tmpdir
Definition: include_gunit.h:20
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
unicharset.h
tesseract::CT_OK_MULTI_UNICHAR
Definition: errorcounter.h:78
trainingsample.h
tesseract::CT_UNICHAR_TOP_OK
Definition: errorcounter.h:70
shapetable.h
tesseract::ShapeRating
Definition: shapetable.h:92
tesseract
Definition: baseapi.h:65
UNICHAR_ID
int UNICHAR_ID
Definition: unichar.h:36
sample
Definition: cluster.h:31
tesseract::MasterTrainer
Definition: mastertrainer.h:69
GenericVector
Definition: baseapi.h:40
tessoptind
int tessoptind
Definition: tessopt.cpp:23
shapeclassifier.h
tesseract::LoadTrainingData
MasterTrainer * LoadTrainingData(int argc, const char *const *argv, bool replication, ShapeTable **shape_table, STRING *file_prefix)
Definition: commontraining.cpp:211
tesseract::TrainingSample
Definition: trainingsample.h:53
GenericVector::clear
void clear()
Definition: genericvector.h:857
tesseract::CT_NUM_RESULTS
Definition: errorcounter.h:84
tesseract::ShapeTable::FindShape
int FindShape(int unichar_id, int font_id) const
Definition: shapetable.cpp:386
tesseract::ShapeTable
Definition: shapetable.h:261
tesseract::ShapeTable::AddShape
int AddShape(int unichar_id, int font_id)
Definition: shapetable.cpp:336
tesseract::CT_FONT_ATTR_ERR
Definition: errorcounter.h:82
log.h
LOG
Definition: cleanapi_test.cc:19
tesseract::CT_UNICHAR_TOPTOP_ERR
Definition: errorcounter.h:77
tessopt.h
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::CT_UNICHAR_TOP1_ERR
Definition: errorcounter.h:74
tesseract::MockClassifier::~MockClassifier
virtual ~MockClassifier()
Definition: mastertrainer_test.cc:80
tesseract::CT_UNICHAR_TOP2_ERR
Definition: errorcounter.h:75