14 #include "absl/strings/str_cat.h"
16 #include "gmock/gmock.h"
// AddACharacter: checks that trainer1 (initialised with traineddata1) encodes
// kTestString but rejects kTestStringRupees, that trainer2 (initialised with
// traineddata2) encodes both, and that the two encodings of kTestString are
// identical once trainer1's null label is replaced by trainer2's.
35 TEST(LangModelTest, AddACharacter) {
// Plain ASCII input that both trainers are expected to encode.
36 constexpr
char kTestString[] =
"Simple ASCII string to encode !@#$%&";
// ASCII input that additionally contains the Rupee sign; only trainer2 is
// expected to encode it (see the final two expectations).
37 constexpr
char kTestStringRupees[] =
"ASCII string with Rupee symbol ₹";
// Path of the eng_beam.unicharset fixture as returned by TestDataNameToPath.
41 std::string unicharset_path = TestDataNameToPath(
"eng_beam.unicharset");
46 LOG(
INFO) <<
"Output dir=" << output_dir <<
"\n";
48 bool pass_through_recoder =
false;
// The word, punctuation and number lists passed to CombineLangModel must all
// be non-empty.
53 EXPECT_GT(words.
size(), 0);
55 EXPECT_GT(puncs.
size(), 0);
57 .
split(
'\n', &numbers);
58 EXPECT_GT(numbers.
size(), 0);
59 bool lang_is_rtl =
false;
// CombineLangModel must return 0 for lang1.
61 EXPECT_EQ(0,
CombineLangModel(unicharset, script_dir, version_str, output_dir,
62 lang1, pass_through_recoder, words, puncs,
63 numbers, lang_is_rtl,
nullptr,
nullptr));
66 file::JoinPath(output_dir, lang1, absl::StrCat(lang1,
".traineddata"));
// NOTE(review): traineddata1 is presumably the
// <output_dir>/<lang1>/<lang1>.traineddata path joined just above — confirm.
68 trainer1.InitCharSet(traineddata1);
// trainer1 must encode the ASCII string; the labels are decoded again and
// logged for inspection.
70 EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
71 STRING test1_decoded = trainer1.DecodeLabels(labels1);
72 std::string test1_str(&test1_decoded[0], test1_decoded.length());
73 LOG(
INFO) <<
"Labels1=" << test1_str <<
"\n";
// Remember the unicharset size so the growth can be checked below.
76 int size_before = unicharset.
size();
// The unicharset must have grown by exactly one entry since size_before was
// taken. NOTE(review): presumably the Rupee symbol was added in between —
// confirm.
80 EXPECT_EQ(size_before + 1, unicharset.
size());
// NOTE(review): presumably a second CombineLangModel call, this time for
// lang2; it must return EXIT_SUCCESS — confirm.
83 EXPECT_EQ(EXIT_SUCCESS,
85 lang2, pass_through_recoder, words, puncs, numbers,
86 lang_is_rtl,
nullptr,
nullptr));
89 file::JoinPath(output_dir, lang2, absl::StrCat(lang2,
".traineddata"));
91 trainer2.InitCharSet(traineddata2);
// trainer2 must still encode the ASCII string; decoded labels are logged.
93 EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
94 STRING test2_decoded = trainer2.DecodeLabels(labels2);
95 std::string test2_str(&test2_decoded[0], test2_decoded.length());
96 LOG(
INFO) <<
"Labels2=" << test2_str <<
"\n";
// trainer2 must also encode the string containing the Rupee sign.
99 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
100 STRING test3_decoded = trainer2.DecodeLabels(labels3);
101 std::string test3_str(&test3_decoded[0], test3_decoded.length());
102 LOG(
INFO) <<
"labels3=" << test3_str <<
"\n";
// trainer2's null label must be exactly one higher than trainer1's.
107 int null1 = trainer1.null_char();
108 int null2 = trainer2.null_char();
109 EXPECT_EQ(null1 + 1, null2);
// Build a copy of labels1 in which every occurrence of null1 is replaced by
// null2, so that it can be compared directly with labels2.
110 std::vector<int> labels1_v(labels1.
size());
111 for (
int i = 0; i < labels1.
size(); ++i) {
112 if (labels1[i] == null1)
113 labels1_v[i] = null2;
115 labels1_v[i] = labels1[i];
// After the null remapping, the remapped labels1 must match labels2
// element for element.
117 EXPECT_THAT(labels1_v,
118 testing::ElementsAreArray(&labels2[0], labels2.
size()));
// Only trainer2 can encode the string that contains the Rupee sign.
121 EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
122 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));
// AddACharacterHindi: same checks as the ASCII variant of this test, but using
// the hin_beam.unicharset fixture and Devanagari input. trainer1 must encode
// kTestString but reject kTestStringRupees, trainer2 must encode both, and
// the encodings of kTestString must agree once the null label is remapped.
126 TEST(LangModelTest, AddACharacterHindi) {
// Hindi input that both trainers are expected to encode.
127 constexpr
char kTestString[] =
"हिन्दी में एक लाइन लिखें";
// Hindi input that additionally contains the Rupee sign and Devanagari
// digits; only trainer2 is expected to encode it (see the final two
// expectations).
128 constexpr
char kTestStringRupees[] =
"हिंदी में रूपये का चिन्ह प्रयोग करें ₹१००.००";
// Path of the hin_beam.unicharset fixture as returned by TestDataNameToPath.
132 std::string unicharset_path = TestDataNameToPath(
"hin_beam.unicharset");
137 LOG(
INFO) <<
"Output dir=" << output_dir <<
"\n";
139 bool pass_through_recoder =
false;
143 .
split(
'\n', &words);
// The word, punctuation and number lists passed to CombineLangModel must all
// be non-empty.
144 EXPECT_GT(words.
size(), 0);
146 EXPECT_GT(puncs.
size(), 0);
148 .
split(
'\n', &numbers);
149 EXPECT_GT(numbers.
size(), 0);
150 bool lang_is_rtl =
false;
// CombineLangModel must return 0 for lang1.
152 EXPECT_EQ(0,
CombineLangModel(unicharset, script_dir, version_str, output_dir,
153 lang1, pass_through_recoder, words, puncs,
154 numbers, lang_is_rtl,
nullptr,
nullptr));
157 file::JoinPath(output_dir, lang1, absl::StrCat(lang1,
".traineddata"));
// NOTE(review): traineddata1 is presumably the
// <output_dir>/<lang1>/<lang1>.traineddata path joined just above — confirm.
158 LSTMTrainer trainer1;
159 trainer1.InitCharSet(traineddata1);
// trainer1 must encode the Hindi string; the labels are decoded again and
// logged for inspection.
161 EXPECT_TRUE(trainer1.EncodeString(kTestString, &labels1));
162 STRING test1_decoded = trainer1.DecodeLabels(labels1);
163 std::string test1_str(&test1_decoded[0], test1_decoded.length());
164 LOG(
INFO) <<
"Labels1=" << test1_str <<
"\n";
// Remember the unicharset size so the growth can be checked below.
167 int size_before = unicharset.
size();
// The unicharset must have grown by exactly one entry since size_before was
// taken. NOTE(review): presumably the Rupee symbol was added in between —
// confirm.
171 EXPECT_EQ(size_before + 1, unicharset.
size());
// NOTE(review): presumably a second CombineLangModel call, this time for
// lang2; it must return EXIT_SUCCESS — confirm.
174 EXPECT_EQ(EXIT_SUCCESS,
176 lang2, pass_through_recoder, words, puncs, numbers,
177 lang_is_rtl,
nullptr,
nullptr));
180 file::JoinPath(output_dir, lang2, absl::StrCat(lang2,
".traineddata"));
181 LSTMTrainer trainer2;
182 trainer2.InitCharSet(traineddata2);
// trainer2 must still encode the Hindi string; decoded labels are logged.
184 EXPECT_TRUE(trainer2.EncodeString(kTestString, &labels2));
185 STRING test2_decoded = trainer2.DecodeLabels(labels2);
186 std::string test2_str(&test2_decoded[0], test2_decoded.length());
187 LOG(
INFO) <<
"Labels2=" << test2_str <<
"\n";
// trainer2 must also encode the string containing the Rupee sign.
190 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels3));
191 STRING test3_decoded = trainer2.DecodeLabels(labels3);
192 std::string test3_str(&test3_decoded[0], test3_decoded.length());
193 LOG(
INFO) <<
"labels3=" << test3_str <<
"\n";
// trainer2's null label must be exactly one higher than trainer1's.
198 int null1 = trainer1.null_char();
199 int null2 = trainer2.null_char();
200 EXPECT_EQ(null1 + 1, null2);
// Build a copy of labels1 in which every occurrence of null1 is replaced by
// null2, so that it can be compared directly with labels2.
201 std::vector<int> labels1_v(labels1.
size());
202 for (
int i = 0; i < labels1.
size(); ++i) {
203 if (labels1[i] == null1)
204 labels1_v[i] = null2;
206 labels1_v[i] = labels1[i];
// After the null remapping, the remapped labels1 must match labels2
// element for element.
208 EXPECT_THAT(labels1_v,
209 testing::ElementsAreArray(&labels2[0], labels2.
size()));
// Only trainer2 can encode the string that contains the Rupee sign.
212 EXPECT_FALSE(trainer1.EncodeString(kTestStringRupees, &labels1));
213 EXPECT_TRUE(trainer2.EncodeString(kTestStringRupees, &labels2));