tesseract  5.0.0-alpha-619-ge9db
lstmtrainer_test.cc
Go to the documentation of this file.
1 // (C) Copyright 2017, Google Inc.
2 // Licensed under the Apache License, Version 2.0 (the "License");
3 // you may not use this file except in compliance with the License.
4 // You may obtain a copy of the License at
5 // http://www.apache.org/licenses/LICENSE-2.0
6 // Unless required by applicable law or agreed to in writing, software
7 // distributed under the License is distributed on an "AS IS" BASIS,
8 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 // See the License for the specific language governing permissions and
10 // limitations under the License.
11 
12 #include "allheaders.h"
13 #include <tesseract/baseapi.h>
14 #include "lstm_test.h"
15 
16 namespace tesseract {
17 namespace {
18 
19 TEST_F(LSTMTrainerTest, EncodesEng) {
20  TestEncodeDecodeBoth("eng",
21  "The quick brown 'fox' jumps over: the lazy dog!");
22 }
23 
24 TEST_F(LSTMTrainerTest, EncodesKan) {
25  TestEncodeDecodeBoth("kan", "ಫ್ರಬ್ರವರಿ ತತ್ವಾಂಶಗಳೆಂದರೆ ಮತ್ತು ಜೊತೆಗೆ ಕ್ರಮವನ್ನು");
26 }
27 
28 TEST_F(LSTMTrainerTest, EncodesKor) {
29  TestEncodeDecodeBoth("kor",
30  "이는 것으로 다시 넣을 수는 있지만 선택의 의미는");
31 }
32 
33 TEST_F(LSTMTrainerTest, MapCoder) {
34  LSTMTrainer fra_trainer;
35  fra_trainer.InitCharSet(TestDataNameToPath("fra/fra.traineddata"));
36  LSTMTrainer deu_trainer;
37  deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
38  // A string that uses characters common to French and German.
39  std::string kTestStr = "The quick brown 'fox' jumps over: the lazy dog!";
40  GenericVector<int> deu_labels;
41  EXPECT_TRUE(deu_trainer.EncodeString(kTestStr.c_str(), &deu_labels));
42  // The french trainer cannot decode them correctly.
43  STRING badly_decoded = fra_trainer.DecodeLabels(deu_labels);
44  std::string bad_str(&badly_decoded[0], badly_decoded.length());
45  LOG(INFO) << "bad_str fra=" << bad_str << "\n";
46  EXPECT_NE(kTestStr, bad_str);
47  // Encode the string as fra.
48  GenericVector<int> fra_labels;
49  EXPECT_TRUE(fra_trainer.EncodeString(kTestStr.c_str(), &fra_labels));
50  // Use the mapper to compute what the labels are as deu.
51  std::vector<int> mapping = fra_trainer.MapRecoder(deu_trainer.GetUnicharset(),
52  deu_trainer.GetRecoder());
53  GenericVector<int> mapped_fra_labels(fra_labels.size(), -1);
54  for (int i = 0; i < fra_labels.size(); ++i) {
55  mapped_fra_labels[i] = mapping[fra_labels[i]];
56  EXPECT_NE(-1, mapped_fra_labels[i]) << "i=" << i << ", ch=" << kTestStr[i];
57  EXPECT_EQ(mapped_fra_labels[i], deu_labels[i])
58  << "i=" << i << ", ch=" << kTestStr[i]
59  << " has deu label=" << deu_labels[i] << ", but mapped to "
60  << mapped_fra_labels[i];
61  }
62  // The german trainer can now decode them correctly.
63  STRING decoded = deu_trainer.DecodeLabels(mapped_fra_labels);
64  std::string ok_str(&decoded[0], decoded.length());
65  LOG(INFO) << "ok_str deu=" << ok_str << "\n";
66  EXPECT_EQ(kTestStr, ok_str);
67 }
68 
69 // Tests that the actual fra model can be converted to the deu character set
70 // and still read an eng image with 100% accuracy.
71 TEST_F(LSTMTrainerTest, ConvertModel) {
72  // Setup a trainer with a deu charset.
73  LSTMTrainer deu_trainer;
74  deu_trainer.InitCharSet(TestDataNameToPath("deu/deu.traineddata"));
75  // Load the fra traineddata, strip out the model, and save to a tmp file.
76  TessdataManager mgr;
77  std::string fra_data =
78  file::JoinPath(TESSDATA_BEST_DIR, "fra.traineddata");
79  CHECK(mgr.Init(fra_data.c_str()));
80  LOG(INFO) << "Load " << fra_data << "\n";
81  std::string model_path = file::JoinPath(FLAGS_test_tmpdir, "fra.lstm");
82  CHECK(mgr.ExtractToFile(model_path.c_str()));
83  LOG(INFO) << "Extract " << model_path << "\n";
84  // Load the fra model into the deu_trainer, and save the converted model.
85  CHECK(deu_trainer.TryLoadingCheckpoint(model_path.c_str(), fra_data.c_str()));
86  LOG(INFO) << "Checkpoint load for " << model_path << " and " << fra_data << "\n";
87  std::string deu_data = file::JoinPath(FLAGS_test_tmpdir, "deu.traineddata");
88  CHECK(deu_trainer.SaveTraineddata(deu_data.c_str()));
89  LOG(INFO) << "Save " << deu_data << "\n";
90  // Now run the saved model on phototest. (See BasicTesseractTest in
91  // baseapi_test.cc).
92  TessBaseAPI api;
94  Pix* src_pix = pixRead(TestingNameToPath("phototest.tif").c_str());
95  CHECK(src_pix);
96  api.SetImage(src_pix);
97  std::unique_ptr<char[]> result(api.GetUTF8Text());
98  std::string truth_text;
99  CHECK_OK(file::GetContents(TestingNameToPath("phototest.gold.txt"),
100  &truth_text, file::Defaults()));
101 
102  EXPECT_STREQ(truth_text.c_str(), result.get());
103  pixDestroy(&src_pix);
104 }
105 
106 } // namespace
107 } // namespace tesseract
file::JoinPath
static std::string JoinPath(const std::string &s1, const std::string &s2)
Definition: include_gunit.h:43
string
std::string string
Definition: equationdetect_test.cc:21
INFO
Definition: log.h:29
CHECK_OK
#define CHECK_OK(test)
Definition: include_gunit.h:62
STRING
Definition: strngs.h:45
file::Defaults
static int Defaults()
Definition: include_gunit.h:39
tesseract::OEM_LSTM_ONLY
Definition: publictypes.h:267
tesseract::TEST_F
TEST_F(EquationFinderTest, IdentifySpecialText)
Definition: equationdetect_test.cc:181
file::GetContents
static bool GetContents(const std::string &filename, std::string *out, int)
Definition: include_gunit.h:31
CHECK
#define CHECK(test)
Definition: include_gunit.h:57
baseapi.h
FLAGS_test_tmpdir
const char * FLAGS_test_tmpdir
Definition: include_gunit.h:20
tesseract
Definition: baseapi.h:65
lstm_test.h
GenericVector< int >
STRING::length
int32_t length() const
Definition: strngs.cpp:187
LOG
Definition: cleanapi_test.cc:19
TessBaseAPI
struct TessBaseAPI TessBaseAPI
Definition: capi.h:72
GenericVector::size
int size() const
Definition: genericvector.h:71