24 #include "absl/strings/str_format.h"
// Fixed count of 100 used by the tests in this file.
// NOTE(review): presumably the number of characters in a generated test
// transcription; no use of kNumChars is visible in this excerpt - confirm.
const int kNumChars = 100;
// Padding amount that the per-language decode tests hand to
// GenerateRandomPaddedOutputs() together with the transcription.
// NOTE(review): presumably the number of extra random timesteps appended after
// the encoded text; the generator body is not visible here - confirm.
const int kPadding = 64;
// Top-choice string for each position of the English dictionary test
// (DISABLED_EngDictionary). The list is nullptr-terminated: consumers count
// entries until they reach nullptr, and the test concatenates all entries to
// build the string it expects when no dictionary is supplied
// ("Gef s wordsright."). With the dictionary loaded the same outputs are
// expected to decode to "Gets words right.".
// An empty string marks a position whose top choice is "no character".
const char* kGWRTops[] = {
    "G", "e", "f", " ", "s", " ",  // "Gef s "
    "w", "o", "r", "d", "s", "",   // "words" + empty slot
    "r", "i", "g", "h", "t", ".",  // "right."
    nullptr};
// Score paired with each kGWRTops entry, in the same order (18 values, no
// terminator). Literals carry the `f` suffix so the float array is not
// initialised through an implicit double-to-float conversion.
const float kGWRTopScores[] = {
    0.99f, 0.85f, 0.87f, 0.55f, 0.99f, 0.65f,
    0.89f, 0.99f, 0.99f, 0.99f, 0.99f, 0.95f,
    0.99f, 0.90f, 0.90f, 0.90f, 0.95f, 0.75f};
// Second-choice string for each position of the English dictionary test,
// aligned index-for-index with kGWRTops and likewise nullptr-terminated.
// An empty string marks a position whose alternative is "no character";
// " " is a literal space.
const char* kGWR2nds[] = {
    "C", "c", "t", "",  "S", "",
    "W", "O", "t", "h", "S", " ",
    "t", "I", "9", "b", "f", ",",
    nullptr};
// Score paired with each kGWR2nds entry, in the same order (18 values, no
// terminator). Literals carry the `f` suffix so the float array is not
// initialised through an implicit double-to-float conversion.
const float kGWR2ndScores[] = {
    0.01f, 0.10f, 0.12f, 0.42f, 0.01f, 0.25f,
    0.10f, 0.01f, 0.01f, 0.01f, 0.01f, 0.05f,
    0.01f, 0.09f, 0.09f, 0.09f, 0.05f, 0.25f};
// Top-choice string for each position of the Chinese dictionary test
// (DISABLED_ChiDictionary), nullptr-terminated. Read in order, the seven
// entries form the string that test expects back from the decoder.
const char* kZHTops[] = {
    "实", "学", "储", "啬", "投", "学", "生",
    nullptr};
// Score paired with each kZHTops entry (7 values, no terminator). Literals
// carry the `f` suffix so the float array is not initialised through an
// implicit double-to-float conversion.
const float kZHTopScores[] = {0.98f, 0.98f, 0.98f, 0.98f,
                              0.98f, 0.98f, 0.98f};
// Second-choice string for each position of the Chinese dictionary test,
// aligned index-for-index with kZHTops and likewise nullptr-terminated.
const char* kZH2nds[] = {
    "学", "储", "投", "生", "学", "生", "实",
    nullptr};
// Score paired with each kZH2nds entry (7 values, no terminator). Literals
// carry the `f` suffix so the float array is not initialised through an
// implicit double-to-float conversion.
const float kZH2ndScores[] = {0.01f, 0.01f, 0.01f, 0.01f,
                              0.01f, 0.01f, 0.01f};
// Top-choice string for each position of the Vietnamese multi-code-sequence
// test (DISABLED_MultiCodeSequences, which loads "vie.d.unicharset").
// nullptr-terminated; entry 3 is a literal space.
const char* kViTops[] = {
    "v", "ậ", "y", " ", "t", "ộ", "i",
    nullptr};
// Score paired with each kViTops entry (7 values, no terminator); only the
// last entry differs (0.97). Literals carry the `f` suffix so the float array
// is not initialised through an implicit double-to-float conversion.
const float kViTopScores[] = {0.98f, 0.98f, 0.98f, 0.98f,
                              0.98f, 0.98f, 0.97f};
// Second-choice string for each position of the Vietnamese test, aligned
// index-for-index with kViTops and likewise nullptr-terminated. Empty strings
// mark positions whose alternative is "no character".
const char* kVi2nds[] = {
    "V", "a", "v", "", "l", "o", "",
    nullptr};
// Score paired with each kVi2nds entry (7 values, no terminator). Literals
// carry the `f` suffix so the float array is not initialised through an
// implicit double-to-float conversion.
const float kVi2ndScores[] = {0.01f, 0.01f, 0.01f, 0.01f,
                              0.01f, 0.01f, 0.01f};
70 class RecodeBeamTest :
public ::testing::Test {
73 std::locale::global(std::locale(
""));
76 RecodeBeamTest() : lstm_dict_(&ccutil_) {}
77 ~RecodeBeamTest() { lstm_dict_.End(); }
80 void LoadUnicharset(
const std::string& unicharset_name) {
82 "radical-stroke.txt");
88 CHECK(ccutil_.unicharset.load_from_file(unicharset_file.c_str()));
89 unichar_null_char_ = ccutil_.unicharset.has_special_codes()
91 : ccutil_.unicharset.size();
92 STRING radical_str(radical_data.c_str());
93 EXPECT_TRUE(recoder_.ComputeEncoding(ccutil_.unicharset, unichar_null_char_,
96 recoder_.EncodeUnichar(unichar_null_char_, &code);
102 STRING encoding = recoder_.GetEncodingAsString(ccutil_.unicharset);
105 LOG(
INFO) <<
"Wrote encoding to:" << output_name <<
"\n";
109 std::string traineddata_name = lang +
".traineddata";
112 lstm_dict_.SetupForLoad(
nullptr);
114 mgr.
Init(traineddata_file.c_str());
115 lstm_dict_.LoadLSTM(lang.c_str(), &mgr);
116 lstm_dict_.FinishLoad();
124 for (
int i = 0; i < transcription.
size(); ++i) {
125 truth_utf8 += ccutil_.unicharset.id_to_unichar(transcription[i]);
128 ExpectCorrect(output, truth_utf8,
nullptr, &words);
134 beam_search.Decode(output, 3.5, -0.125, -25.0,
nullptr);
138 beam_search.ExtractBestPathAsLabels(&labels, &xcoords);
139 LOG(
INFO) <<
"Labels size = " << labels.
size() <<
" coords "
140 << xcoords.
size() <<
"\n";
144 for (
int start = 0; start < labels.
size(); start = end) {
147 int uni_id = INVALID_UNICHAR_ID;
149 code.
Set(code.
length(), labels[index++]);
150 uni_id = recoder_.DecodeUnichar(code);
151 }
while (index < labels.
size() &&
152 code.
length() < RecodedCharID::kMaxCodeLen &&
153 (uni_id == INVALID_UNICHAR_ID ||
154 !recoder_.IsValidFirstCode(labels[index])));
155 EXPECT_NE(INVALID_UNICHAR_ID, uni_id)
156 <<
"index=" << index <<
"/" << labels.
size();
160 if (uni_id != unichar_null_char_ && decoded.size() < truth_utf8.size())
161 decoded += ccutil_.unicharset.id_to_unichar(uni_id);
164 EXPECT_EQ(truth_utf8, decoded);
169 beam_search.ExtractBestPathAsUnicharIds(
false, &ccutil_.unicharset,
170 &unichar_ids, &certainties,
173 float total_rating = 0.0f;
174 for (
int u = 0; u < unichar_ids.
size(); ++u) {
178 if (u_decoded.size() < truth_utf8.size()) {
179 const char* str = ccutil_.unicharset.id_to_unichar(unichar_ids[u]);
180 total_rating += ratings[u];
181 LOG(
INFO) << absl::StrFormat(
"%d:u_id=%d=%s, c=%g, r=%g, r_sum=%g @%d", u,
182 unichar_ids[u], str, certainties[u],
183 ratings[u], total_rating, xcoords[u]) <<
"\n";
184 if (str[0] ==
' ') total_rating = 0.0f;
188 EXPECT_EQ(truth_utf8, u_decoded);
191 TBOX line_box(0, 0, 100, 10);
192 for (
int i = 0; i < 2; ++i) {
193 beam_search.ExtractBestPathAsWords(line_box, 1.0f,
false,
194 &ccutil_.unicharset, words);
196 for (
int w = 0; w < words->
size(); ++w) {
198 if (w_decoded.size() < truth_utf8.size()) {
199 if (!w_decoded.empty() && word->
word->
space()) w_decoded +=
" ";
202 LOG(
INFO) << absl::StrFormat(
"Word:%d = %s, c=%g, r=%g, perm=%d", w,
208 std::string w_trunc(w_decoded.data(), truth_utf8.size());
209 if (truth_utf8 != w_trunc) {
213 w_trunc.assign(w_decoded.data(), truth_utf8.size());
215 EXPECT_EQ(truth_utf8, w_trunc);
222 int width = unichar_ids.
size() * 2 * RecodedCharID::kMaxCodeLen;
223 int num_codes = recoder_.code_range();
227 for (
int t = 0; t < width; ++t) {
228 for (
int i = 0; i < num_codes; ++i)
232 for (
int i = 0; i < unichar_ids.
size(); ++i) {
234 int len = recoder_.EncodeUnichar(unichar_ids[i], &code);
236 for (
int j = 0; j < len; ++j) {
238 if (j > 0 && code(j) == code(j - 1)) {
242 outputs(t++, code(j)) = 1.0f;
248 for (
int t = 0; t < width; ++t) {
250 for (
int i = 0; i < num_codes; ++i) sum += outputs(t, i);
251 for (
int i = 0; i < num_codes; ++i) outputs(t, i) /= sum;
258 int EncodeUTF8(
const char* utf8_str,
float score,
int start_t,
TRand* random,
262 EXPECT_TRUE(ccutil_.unicharset.encode_string(utf8_str,
true, &unichar_ids,
264 if (unichar_ids.
empty() || utf8_str[0] ==
'\0') {
266 unichar_ids.
push_back(unichar_null_char_);
268 int num_ids = unichar_ids.
size();
269 for (
int u = 0; u < num_ids; ++u) {
271 int len = recoder_.EncodeUnichar(unichar_ids[u], &code);
273 for (
int i = 0; i < len; ++i) {
275 (*outputs)(t++, code(i)) = score;
276 if (random !=
nullptr &&
277 t + (num_ids - u) * RecodedCharID::kMaxCodeLen < outputs->
dim1()) {
279 for (
int d = 0; d < dups; ++d) {
281 (*outputs)(t++, code(i)) = score;
285 if (random !=
nullptr &&
286 t + (num_ids - u) * RecodedCharID::kMaxCodeLen < outputs->
dim1()) {
288 for (
int d = 0; d < dups; ++d) {
301 const float scores1[],
302 const char* chars2[],
303 const float scores2[],
306 while (chars1[width] !=
nullptr) ++width;
307 int padding = width * RecodedCharID::kMaxCodeLen;
308 int num_codes = recoder_.code_range();
311 for (
int i = 0; i < width; ++i) {
314 int end_t2 = EncodeUTF8(chars2[i], scores2[i], t, random, &outputs);
315 int end_t1 = EncodeUTF8(chars1[i], scores1[i], t, random, &outputs);
317 int max_t = std::max(end_t1, end_t2);
318 int min_t = std::min(end_t1, end_t2);
320 double total_score = 0.0;
321 for (
int j = 0; j < num_codes; ++j) total_score += outputs(t, j);
322 double null_remainder = (1.0 - total_score) / 2.0;
323 double remainder = null_remainder / (num_codes - 2);
327 remainder += remainder;
329 for (
int j = 0; j < num_codes; ++j) {
330 if (outputs(t, j) == 0.0f) outputs(t, j) = remainder;
336 while (t < width + padding) {
342 int unichar_null_char_ = 0;
348 TEST_F(RecodeBeamTest, DoesChinese) {
349 LOG(
INFO) <<
"Testing chi_tra" <<
"\n";
350 LoadUnicharset(
"chi_tra.unicharset");
356 GenerateRandomPaddedOutputs(transcription, kPadding);
357 ExpectCorrect(outputs, transcription);
358 LOG(
INFO) <<
"Testing chi_sim" <<
"\n";
359 LoadUnicharset(
"chi_sim.unicharset");
361 transcription.
clear();
364 outputs = GenerateRandomPaddedOutputs(transcription, kPadding);
365 ExpectCorrect(outputs, transcription);
368 TEST_F(RecodeBeamTest, DoesJapanese) {
369 LOG(
INFO) <<
"Testing jpn" <<
"\n";
370 LoadUnicharset(
"jpn.unicharset");
376 GenerateRandomPaddedOutputs(transcription, kPadding);
377 ExpectCorrect(outputs, transcription);
380 TEST_F(RecodeBeamTest, DoesKorean) {
381 LOG(
INFO) <<
"Testing kor" <<
"\n";
382 LoadUnicharset(
"kor.unicharset");
388 GenerateRandomPaddedOutputs(transcription, kPadding);
389 ExpectCorrect(outputs, transcription);
392 TEST_F(RecodeBeamTest, DoesKannada) {
393 LOG(
INFO) <<
"Testing kan" <<
"\n";
394 LoadUnicharset(
"kan.unicharset");
400 GenerateRandomPaddedOutputs(transcription, kPadding);
401 ExpectCorrect(outputs, transcription);
404 TEST_F(RecodeBeamTest, DoesMarathi) {
405 LOG(
INFO) <<
"Testing mar" <<
"\n";
406 LoadUnicharset(
"mar.unicharset");
412 GenerateRandomPaddedOutputs(transcription, kPadding);
413 ExpectCorrect(outputs, transcription);
416 TEST_F(RecodeBeamTest, DoesEnglish) {
417 LOG(
INFO) <<
"Testing eng" <<
"\n";
418 LoadUnicharset(
"eng.unicharset");
424 GenerateRandomPaddedOutputs(transcription, kPadding);
425 ExpectCorrect(outputs, transcription);
428 TEST_F(RecodeBeamTest, DISABLED_EngDictionary) {
429 LOG(
INFO) <<
"Testing eng dictionary" <<
"\n";
430 LoadUnicharset(
"eng_beam.unicharset");
432 kGWRTops, kGWRTopScores, kGWR2nds, kGWR2ndScores,
nullptr);
434 for (
int i = 0; kGWRTops[i] !=
nullptr; ++i) default_str += kGWRTops[i];
436 ExpectCorrect(outputs, default_str,
nullptr, &words);
438 LoadDict(
"eng_beam");
439 ExpectCorrect(outputs,
"Gets words right.", &lstm_dict_, &words);
442 TEST_F(RecodeBeamTest, DISABLED_ChiDictionary) {
443 LOG(
INFO) <<
"Testing zh_hans dictionary" <<
"\n";
444 LoadUnicharset(
"zh_hans.unicharset");
446 kZHTops, kZHTopScores, kZH2nds, kZH2ndScores,
nullptr);
448 ExpectCorrect(outputs,
"实学储啬投学生",
nullptr, &words);
450 EXPECT_EQ(7, words.
size());
451 for (
int w = 0; w < words.
size(); ++w) {
456 ExpectCorrect(outputs,
"实学储啬投学生", &lstm_dict_, &words);
458 const int kNumWords = 5;
460 const char* kWords[kNumWords] = {
"实学",
"储",
"啬",
"投",
"学生"};
465 EXPECT_EQ(kNumWords, words.
size());
466 for (
int w = 0; w < kNumWords && w < words.
size(); ++w) {
467 EXPECT_STREQ(kWords[w], words[w]->best_choice->unichar_string().c_str());
468 EXPECT_EQ(kWordPerms[w], words[w]->best_choice->permuter());
474 TEST_F(RecodeBeamTest, DISABLED_MultiCodeSequences) {
475 LOG(
INFO) <<
"Testing duplicates in multi-code sequences" <<
"\n";
476 LoadUnicharset(
"vie.d.unicharset");
480 kViTops, kViTopScores, kVi2nds, kVi2ndScores, &random);
486 ExpectCorrect(outputs, truth_str,
nullptr, &words);