tesseract  5.0.0-alpha-619-ge9db
lang_model_test.cc
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
12 #include <string> // for std::string
13 
14 #include "absl/strings/str_cat.h"
15 
16 #include "gmock/gmock.h" // for testing::ElementsAreArray
17 
18 #include "include_gunit.h"
19 #include "lang_model_helpers.h"
20 #include "log.h" // for LOG
21 #include "lstmtrainer.h"
23 
24 namespace tesseract {
25 namespace {
26 
27 std::string TestDataNameToPath(const std::string& name) {
28  return file::JoinPath(TESTING_DIR, name);
29 }
30 
31 // This is an integration test that verifies that CombineLangModel works to
32 // the extent that an LSTMTrainer can be initialized with the result, and it
33 // can encode strings. More importantly, the test verifies that adding an extra
34 // character to the unicharset does not change the encoding of strings.
35 TEST(LangModelTest, AddACharacter) {
36  constexpr char kTestString[] = "Simple ASCII string to encode !@#$%&";
37  constexpr char kTestStringRupees[] = "ASCII string with Rupee symbol ₹";
38  // Setup the arguments.
39  std::string script_dir = LANGDATA_DIR;
40  std::string eng_dir = file::JoinPath(script_dir, "eng");
41  std::string unicharset_path = TestDataNameToPath("eng_beam.unicharset");
42  UNICHARSET unicharset;
43  EXPECT_TRUE(unicharset.load_from_file(unicharset_path.c_str()));
44  std::string version_str = "TestVersion";
45  std::string output_dir = FLAGS_test_tmpdir;
46  LOG(INFO) << "Output dir=" << output_dir << "\n";
47  std::string lang1 = "eng";
48  bool pass_through_recoder = false;
49  GenericVector<STRING> words, puncs, numbers;
50  // If these reads fail, we get a warning message and an empty list of words.
51  ReadFile(file::JoinPath(eng_dir, "eng.wordlist"), nullptr)
52  .split('\n', &words);
53  EXPECT_GT(words.size(), 0);
54  ReadFile(file::JoinPath(eng_dir, "eng.punc"), nullptr).split('\n', &puncs);
55  EXPECT_GT(puncs.size(), 0);
56  ReadFile(file::JoinPath(eng_dir, "eng.numbers"), nullptr)
57  .split('\n', &numbers);
58  EXPECT_GT(numbers.size(), 0);
59  bool lang_is_rtl = false;
60  // Generate the traineddata file.
61  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, version_str, output_dir,
62  lang1, pass_through_recoder, words, puncs,
63  numbers, lang_is_rtl, nullptr, nullptr));
64  // Init a trainer with it, and encode kTestString.
65  std::string traineddata1 =
66  file::JoinPath(output_dir, lang1, absl::StrCat(lang1, ".traineddata"));
67  LSTMTrainer trainer1;
68  trainer1.InitCharSet(traineddata1);
69  GenericVector<int> labels1;
70  EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
71  STRING test1_decoded = trainer1.DecodeLabels(labels1);
72  std::string test1_str(&test1_decoded[0], test1_decoded.length());
73  LOG(INFO) << "Labels1=" << test1_str << "\n";
74 
75  // Add a new character to the unicharset and try again.
76  int size_before = unicharset.size();
77  unicharset.unichar_insert("₹");
78  SetupBasicProperties(/*report_errors*/ true, /*decompose (NFD)*/ false,
79  &unicharset);
80  EXPECT_EQ(size_before + 1, unicharset.size());
81  // Generate the traineddata file.
82  std::string lang2 = "extended";
83  EXPECT_EQ(EXIT_SUCCESS,
84  CombineLangModel(unicharset, script_dir, version_str, output_dir,
85  lang2, pass_through_recoder, words, puncs, numbers,
86  lang_is_rtl, nullptr, nullptr));
87  // Init a trainer with it, and encode kTestString.
88  std::string traineddata2 =
89  file::JoinPath(output_dir, lang2, absl::StrCat(lang2, ".traineddata"));
90  LSTMTrainer trainer2;
91  trainer2.InitCharSet(traineddata2);
92  GenericVector<int> labels2;
93  EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
94  STRING test2_decoded = trainer2.DecodeLabels(labels2);
95  std::string test2_str(&test2_decoded[0], test2_decoded.length());
96  LOG(INFO) << "Labels2=" << test2_str << "\n";
97  // encode kTestStringRupees.
98  GenericVector<int> labels3;
99  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
100  STRING test3_decoded = trainer2.DecodeLabels(labels3);
101  std::string test3_str(&test3_decoded[0], test3_decoded.length());
102  LOG(INFO) << "labels3=" << test3_str << "\n";
103  // Copy labels1 to a std::vector, renumbering the null char to match trainer2.
104  // Since Tensor Flow's CTC implementation insists on having the null be the
105  // last label, and we want to be compatible, null has to be renumbered when
106  // we add a class.
107  int null1 = trainer1.null_char();
108  int null2 = trainer2.null_char();
109  EXPECT_EQ(null1 + 1, null2);
110  std::vector<int> labels1_v(labels1.size());
111  for (int i = 0; i < labels1.size(); ++i) {
112  if (labels1[i] == null1)
113  labels1_v[i] = null2;
114  else
115  labels1_v[i] = labels1[i];
116  }
117  EXPECT_THAT(labels1_v,
118  testing::ElementsAreArray(&labels2[0], labels2.size()));
119  // To make sure we we are not cheating somehow, we can now encode the Rupee
120  // symbol, which we could not do before.
121  EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
122  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
123 }
124 
125 // Same as above test, for hin instead of eng
126 TEST(LangModelTest, AddACharacterHindi) {
127  constexpr char kTestString[] = "हिन्दी में एक लाइन लिखें";
128  constexpr char kTestStringRupees[] = "हिंदी में रूपये का चिन्ह प्रयोग करें ₹१००.००";
129  // Setup the arguments.
130  std::string script_dir = LANGDATA_DIR;
131  std::string hin_dir = file::JoinPath(script_dir, "hin");
132  std::string unicharset_path = TestDataNameToPath("hin_beam.unicharset");
133  UNICHARSET unicharset;
134  EXPECT_TRUE(unicharset.load_from_file(unicharset_path.c_str()));
135  std::string version_str = "TestVersion";
136  std::string output_dir = FLAGS_test_tmpdir;
137  LOG(INFO) << "Output dir=" << output_dir << "\n";
138  std::string lang1 = "hin";
139  bool pass_through_recoder = false;
140  GenericVector<STRING> words, puncs, numbers;
141  // If these reads fail, we get a warning message and an empty list of words.
142  ReadFile(file::JoinPath(hin_dir, "hin.wordlist"), nullptr)
143  .split('\n', &words);
144  EXPECT_GT(words.size(), 0);
145  ReadFile(file::JoinPath(hin_dir, "hin.punc"), nullptr).split('\n', &puncs);
146  EXPECT_GT(puncs.size(), 0);
147  ReadFile(file::JoinPath(hin_dir, "hin.numbers"), nullptr)
148  .split('\n', &numbers);
149  EXPECT_GT(numbers.size(), 0);
150  bool lang_is_rtl = false;
151  // Generate the traineddata file.
152  EXPECT_EQ(0, CombineLangModel(unicharset, script_dir, version_str, output_dir,
153  lang1, pass_through_recoder, words, puncs,
154  numbers, lang_is_rtl, nullptr, nullptr));
155  // Init a trainer with it, and encode kTestString.
156  std::string traineddata1 =
157  file::JoinPath(output_dir, lang1, absl::StrCat(lang1, ".traineddata"));
158  LSTMTrainer trainer1;
159  trainer1.InitCharSet(traineddata1);
160  GenericVector<int> labels1;
161  EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
162  STRING test1_decoded = trainer1.DecodeLabels(labels1);
163  std::string test1_str(&test1_decoded[0], test1_decoded.length());
164  LOG(INFO) << "Labels1=" << test1_str << "\n";
165 
166  // Add a new character to the unicharset and try again.
167  int size_before = unicharset.size();
168  unicharset.unichar_insert("₹");
169  SetupBasicProperties(/*report_errors*/ true, /*decompose (NFD)*/ false,
170  &unicharset);
171  EXPECT_EQ(size_before + 1, unicharset.size());
172  // Generate the traineddata file.
173  std::string lang2 = "extendedhin";
174  EXPECT_EQ(EXIT_SUCCESS,
175  CombineLangModel(unicharset, script_dir, version_str, output_dir,
176  lang2, pass_through_recoder, words, puncs, numbers,
177  lang_is_rtl, nullptr, nullptr));
178  // Init a trainer with it, and encode kTestString.
179  std::string traineddata2 =
180  file::JoinPath(output_dir, lang2, absl::StrCat(lang2, ".traineddata"));
181  LSTMTrainer trainer2;
182  trainer2.InitCharSet(traineddata2);
183  GenericVector<int> labels2;
184  EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
185  STRING test2_decoded = trainer2.DecodeLabels(labels2);
186  std::string test2_str(&test2_decoded[0], test2_decoded.length());
187  LOG(INFO) << "Labels2=" << test2_str << "\n";
188  // encode kTestStringRupees.
189  GenericVector<int> labels3;
190  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
191  STRING test3_decoded = trainer2.DecodeLabels(labels3);
192  std::string test3_str(&test3_decoded[0], test3_decoded.length());
193  LOG(INFO) << "labels3=" << test3_str << "\n";
194  // Copy labels1 to a std::vector, renumbering the null char to match trainer2.
195  // Since Tensor Flow's CTC implementation insists on having the null be the
196  // last label, and we want to be compatible, null has to be renumbered when
197  // we add a class.
198  int null1 = trainer1.null_char();
199  int null2 = trainer2.null_char();
200  EXPECT_EQ(null1 + 1, null2);
201  std::vector<int> labels1_v(labels1.size());
202  for (int i = 0; i < labels1.size(); ++i) {
203  if (labels1[i] == null1)
204  labels1_v[i] = null2;
205  else
206  labels1_v[i] = labels1[i];
207  }
208  EXPECT_THAT(labels1_v,
209  testing::ElementsAreArray(&labels2[0], labels2.size()));
210  // To make sure we we are not cheating somehow, we can now encode the Rupee
211  // symbol, which we could not do before.
212  EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
213  EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
214 }
215 
216 } // namespace
217 } // namespace tesseract
UNICHARSET::load_from_file
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:378
file::JoinPath
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:43
string
std::string string
Definition: equationdetect_test.cc:21
INFO
Definition: log.h:29
tesseract::CombineLangModel
int CombineLangModel(const UNICHARSET &unicharset, const std::string &script_dir, const std::string &version_str, const std::string &output_dir, const std::string &lang, bool pass_through_recoder, const GenericVector< STRING > &words, const GenericVector< STRING > &puncs, const GenericVector< STRING > &numbers, bool lang_is_rtl, FileReader reader, FileWriter writer)
Definition: lang_model_helpers.cpp:185
unicharset_training_utils.h
lstmtrainer.h
STRING
Definition: strngs.h:45
tesseract::SetupBasicProperties
void SetupBasicProperties(bool report_errors, bool decompose, UNICHARSET *unicharset)
Definition: unicharset_training_utils.cpp:40
include_gunit.h
lang_model_helpers.h
FLAGS_test_tmpdir
const char * FLAGS_test_tmpdir
Definition: include_gunit.h:20
UNICHARSET
Definition: unicharset.h:145
tesseract
Definition: baseapi.h:65
GenericVector< STRING >
tesseract::ReadFile
STRING ReadFile(const std::string &filename, FileReader reader)
Definition: lang_model_helpers.cpp:57
log.h
LOG
Definition: cleanapi_test.cc:19
GenericVector::size
int size() const
Definition: genericvector.h:71
UNICHARSET::unichar_insert
void unichar_insert(const char *const unichar_repr, OldUncleanUnichars old_style)
Definition: unicharset.cpp:625
UNICHARSET::size
int size() const
Definition: unicharset.h:341
STRING::split
void split(char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:275