25 #include <unordered_set>
37 5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
40 static const char* kNodeContNames[] = {
"Anything",
"OnlyDup",
"NoDup"};
45 if (
code == null_char) {
54 if (depth > 0 &&
prev !=
nullptr) {
56 prev->
Print(null_char, unicharset, depth - 1);
64 int null_char,
bool simple_text,
Dict* dict)
70 space_delimited_(true),
71 is_simple_text_(simple_text),
74 space_delimited_ =
false;
79 double cert_offset,
double worst_dict_cert,
80 const UNICHARSET* charset,
int lstm_choice_mode) {
82 int width = output.
Width();
84 for (
int t = 0; t < width; ++t) {
85 ComputeTopN(output.
f(t), output.
NumFeatures(), kBeamWidths[0]);
86 DecodeStep(output.
f(t), t, dict_ratio, cert_offset, worst_dict_cert,
88 if (lstm_choice_mode) {
89 SaveMostCertainChoices(output.
f(t), output.
NumFeatures(), charset, t);
94 double dict_ratio,
double cert_offset,
95 double worst_dict_cert,
98 int width = output.
dim1();
99 for (
int t = 0; t < width; ++t) {
100 ComputeTopN(output[t], output.
dim2(), kBeamWidths[0]);
101 DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
108 double worst_dict_cert,
110 int lstm_choice_mode) {
111 secondary_beam_.
clear();
113 int width = output.
Width();
114 int bucketNumber = 0;
115 for (
int t = 0; t < width; ++t) {
123 DecodeSecondaryStep(output.
f(t), t, dict_ratio, cert_offset, worst_dict_cert,
128 void RecodeBeamSearch::SaveMostCertainChoices(
const float* outputs,
132 std::vector<std::pair<const char*, float>> choices;
133 for (
int i = 0; i < num_outputs; ++i) {
134 if (outputs[i] >= 0.01f) {
136 if (i + 2 >= num_outputs) {
146 while (choices.size() > pos && choices[pos].second > outputs[i]) {
149 choices.insert(choices.begin() + pos,
150 std::pair<const char*, float>(
character, outputs[i]));
158 std::vector<std::vector<std::pair<const char*, float>>> segment;
165 std::vector<std::vector<std::pair<const char*, float>>>
167 std::vector<std::vector<std::vector<std::pair<const char*, float>>>>*
168 segmentedTimesteps) {
169 std::vector<std::vector<std::pair<const char*, float>>> combined_timesteps;
171 for (
int j = 0; j < (*segmentedTimesteps)[i].size(); ++j) {
175 return combined_timesteps;
178 void RecodeBeamSearch::calculateCharBoundaries(std::vector<int>* starts,
179 std::vector<int>* ends,
180 std::vector<int>* char_bounds_,
182 char_bounds_->push_back(0);
183 for (
int i = 0; i < ends->size(); ++i) {
184 int middle = ((*starts)[i+1]-(*ends)[i])/2;
185 char_bounds_->push_back((*ends)[i] + middle);
187 char_bounds_->pop_back();
188 char_bounds_->push_back(maxWidth);
197 ExtractBestPaths(&best_nodes,
nullptr);
200 int width = best_nodes.
size();
202 int label = best_nodes[t]->code;
203 if (label != null_char_) {
207 while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
220 ExtractBestPaths(&best_nodes,
nullptr);
221 ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
223 DebugPath(unicharset, best_nodes);
224 DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
231 float scale_factor,
bool debug,
234 int lstm_choice_mode) {
243 ExtractBestPaths(&best_nodes, &second_nodes);
245 DebugPath(unicharset, best_nodes);
246 ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
248 tprintf(
"\nSecond choice path:\n");
249 DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
255 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
257 int num_ids = unichar_ids.
size();
259 DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
264 float prev_space_cert = 0.0f;
265 for (
int word_start = 0; word_start < num_ids; word_start = word_end) {
266 for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
271 int index = xcoords[word_end];
272 if (best_nodes[index]->start_of_word)
break;
278 float space_cert = 0.0f;
279 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE)
280 space_cert = certs[word_end];
282 word_start > 0 && unichar_ids[word_start - 1] ==
UNICHAR_SPACE;
285 InitializeWord(leading_space, line_box, word_start, word_end,
286 std::min(space_cert, prev_space_cert), unicharset,
287 xcoords, scale_factor);
288 for (
int i = word_start; i < word_end; ++i) {
289 auto* choices =
new BLOB_CHOICE_LIST;
290 BLOB_CHOICE_IT bc_it(choices);
292 unichar_ids[i], ratings[i], certs[i], -1, 1.0f,
294 int col = i - word_start;
295 choice->set_matrix_cell(col, col);
296 bc_it.add_after_then_move(choice);
299 int index = xcoords[word_end - 1];
302 prev_space_cert = space_cert;
303 if (word_end < num_ids && unichar_ids[word_end] ==
UNICHAR_SPACE)
317 bool secondary)
const {
318 std::vector<std::vector<const RecodeNode*>> topology;
319 std::unordered_set<const RecodeNode*> visited;
322 for (
int step = beam->
size()-1; step >=0; --step) {
323 std::vector<const RecodeNode*> layer;
324 topology.push_back(layer);
327 for (
int step = beam->
size() - 1; step >= 0; --step) {
329 beam->
get(step)->beams_->heap();
330 for (
int node = 0; node < heaps->
size(); ++node) {
333 while (curr !=
nullptr && !visited.count(curr)) {
334 visited.insert(curr);
335 topology[step - backtracker].push_back(curr);
343 for (std::vector<const RecodeNode*> layer: topology) {
353 if (node->unichar_id != INVALID_UNICHAR_ID) {
355 intCode = node->unichar_id;
356 }
else if(node->code == null_char_) {
364 const char* prevCode;
366 if (node->prev !=
nullptr) {
367 prevScore = node->prev->score;
368 if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
370 intPrevCode = node->prev->unichar_id;
371 }
else if (node->code == null_char_) {
382 tprintf(
"%x(|)%f(>)%x(|)%f\n", intPrevCode,
383 prevScore, intCode, node->score);
385 tprintf(
"%s(|)%f(>)%s(|)%f\n", prevCode,
386 prevScore, code, node->score);
401 if (secondary_beam_.
empty()) {
402 currentBeam = &beam_;
404 currentBeam = &secondary_beam_;
415 std::vector<const RecodeNode*> best;
417 for (
int i = 0; i < heaps->
size(); ++i) {
418 bool validChar =
false;
421 while (node !=
nullptr && backcounter < backpath) {
422 if (node->
code != null_char_ && node->
unichar_id != INVALID_UNICHAR_ID) {
429 if (validChar) best.push_back(&heaps->
get(i).
data);
434 ExtractPath(best[0], &best_nodes, backpath);
435 ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
438 if (!unichar_ids.
empty()) {
440 for (
int i = 1; i < unichar_ids.
size(); ++i) {
441 if (ratings[i] < ratings[bestPos])
445 for (
int i = 0; i < best_nodes.
size(); ++i) {
446 if (best_nodes[i]->unichar_id == unichar_ids[bestPos]) {
447 bestCode = best_nodes[i]->code;
451 std::unordered_set<int> excludeCodeList;
452 for (
int node = 0; node < best_nodes.
size(); ++node) {
453 if (best_nodes[node]->code != null_char_) {
454 excludeCodeList.insert(best_nodes[node]->code);
458 for (
auto elem : excludeCodeList) {
466 int id = unichar_ids[bestPos];
468 float rating = ratings[bestPos];
470 std::pair<const char*, float>(result, rating));
472 std::vector<std::pair<const char*, float>> choice;
473 int id = unichar_ids[bestPos];
475 float rating = ratings[bestPos];
476 choice.push_back(std::pair<const char*, float>(result, rating));
482 std::unordered_set<int> excludeCodeList;
486 std::vector<std::pair<const char*, float>> choice;
491 secondary_beam_.
clear();
496 for (
int p = 0; p < beam_size_; ++p) {
497 for (
int d = 0; d < 2; ++d) {
498 for (
int c = 0; c <
NC_COUNT; ++c) {
499 auto cont = static_cast<NodeContinuation>(c);
501 if (beam_[p]->beams_[index].empty())
continue;
503 tprintf(
"Position %d: %s+%s beam\n", p, d ?
"Dict" :
"Non-Dict",
505 DebugBeamPos(unicharset, beam_[p]->beams_[index]);
512 void RecodeBeamSearch::DebugBeamPos(
const UNICHARSET& unicharset,
517 int heap_size = heap.
size();
518 for (
int i = 0; i < heap_size; ++i) {
521 if (null_best ==
nullptr || null_best->
score < node->
score)
524 if (unichar_bests[node->
unichar_id] ==
nullptr ||
530 for (
int u = 0; u < unichar_bests.
size(); ++u) {
531 if (unichar_bests[u] !=
nullptr) {
532 const RecodeNode& node = *unichar_bests[u];
533 node.Print(null_char_, unicharset, 1);
536 if (null_best !=
nullptr) {
537 null_best->
Print(null_char_, unicharset, 1);
544 void RecodeBeamSearch::ExtractPathAsUnicharIds(
548 std::vector<int>* character_boundaries) {
553 std::vector<int> starts;
554 std::vector<int> ends;
557 int width = best_nodes.
size();
559 double certainty = 0.0;
561 while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
562 double cert = best_nodes[t++]->certainty;
563 if (cert < certainty) certainty = cert;
568 int unichar_id = best_nodes[t]->unichar_id;
570 best_nodes[t]->permuter !=
NO_PERM) {
573 if (certainty < certs->back()) certs->
back() = certainty;
574 ratings->
back() += rating;
581 double cert = best_nodes[t++]->certainty;
585 best_nodes[t - 1]->permuter ==
NO_PERM)) {
589 }
while (t < width && best_nodes[t]->duplicate);
593 }
else if (!certs->
empty()) {
594 if (certainty < certs->back()) certs->
back() = certainty;
595 ratings->
back() += rating;
598 starts.push_back(width);
599 if (character_boundaries !=
nullptr) {
600 calculateCharBoundaries(&starts, &ends, character_boundaries, width);
607 WERD_RES* RecodeBeamSearch::InitializeWord(
bool leading_space,
608 const TBOX& line_box,
int word_start,
609 int word_end,
float space_certainty,
612 float scale_factor) {
615 C_BLOB_IT b_it(&blobs);
616 for (
int i = word_start; i < word_end; ++i) {
620 box.
scale(scale_factor);
622 box.set_top(line_box.
top());
627 WERD* word =
new WERD(&blobs, leading_space,
nullptr);
629 auto* word_res =
new WERD_RES(word);
630 word_res->end = word_end - word_start + leading_space;
631 word_res->uch_set = unicharset;
632 word_res->combination =
true;
633 word_res->space_certainty = space_certainty;
634 word_res->ratings =
new MATRIX(word_end - word_start, 1);
640 void RecodeBeamSearch::ComputeTopN(
const float* outputs,
int num_outputs,
646 for (
int i = 0; i < num_outputs; ++i) {
647 if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
648 TopPair entry(outputs[i], i);
649 top_heap_.Push(&entry);
650 if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
653 while (!top_heap_.empty()) {
655 top_heap_.Pop(&entry);
656 if (top_heap_.size() > 1) {
657 top_n_flags_[entry.data] =
TN_TOPN;
659 top_n_flags_[entry.data] =
TN_TOP2;
660 if (top_heap_.empty())
661 top_code_ = entry.data;
663 second_code_ = entry.data;
666 top_n_flags_[null_char_] =
TN_TOP2;
669 void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int>* exList,
670 const float* outputs,
int num_outputs,
676 for (
int i = 0; i < num_outputs; ++i) {
677 if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key)
678 && !exList->count(i)) {
679 TopPair entry(outputs[i], i);
680 top_heap_.Push(&entry);
681 if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
684 while (!top_heap_.empty()) {
686 top_heap_.Pop(&entry);
687 if (top_heap_.size() > 1) {
688 top_n_flags_[entry.data] =
TN_TOPN;
690 top_n_flags_[entry.data] =
TN_TOP2;
691 if (top_heap_.empty())
692 top_code_ = entry.data;
694 second_code_ = entry.data;
697 top_n_flags_[null_char_] =
TN_TOP2;
703 void RecodeBeamSearch::DecodeStep(
const float* outputs,
int t,
704 double dict_ratio,
double cert_offset,
705 double worst_dict_cert,
708 RecodeBeam* step = beam_[t];
714 charset, dict_ratio, cert_offset, worst_dict_cert, step);
715 if (dict_ !=
nullptr) {
717 charset, dict_ratio, cert_offset, worst_dict_cert, step);
720 RecodeBeam* prev = beam_[t - 1];
723 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
725 ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
726 tprintf(
"Step %d: Dawg beam %d:\n", t, i);
727 DebugPath(charset, path);
730 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
732 ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
733 tprintf(
"Step %d: Non-Dawg beam %d:\n", t, i);
734 DebugPath(charset, path);
742 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
743 auto top_n = static_cast<TopNState>(tn);
744 for (
int index = 0; index <
kNumBeams; ++index) {
748 for (
int i = prev->beams_[index].size() - 1; i >= 0; --i) {
749 ContinueContext(&prev->beams_[index].get(i).data, index, outputs, top_n,
750 charset, dict_ratio, cert_offset, worst_dict_cert, step);
753 for (
int index = 0; index <
kNumBeams; ++index) {
755 total_beam += step->beams_[index].size();
760 for (
int c = 0; c <
NC_COUNT; ++c) {
761 if (step->best_initial_dawgs_[c].code >= 0) {
762 int index =
BeamIndex(
true, static_cast<NodeContinuation>(c), 0);
764 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
771 void RecodeBeamSearch::DecodeSecondaryStep(
const float* outputs,
int t,
772 double dict_ratio,
double cert_offset,
773 double worst_dict_cert,
775 if (t == secondary_beam_.
size()) secondary_beam_.
push_back(
new RecodeBeam);
776 RecodeBeam* step = secondary_beam_[t];
781 charset, dict_ratio, cert_offset, worst_dict_cert, step);
782 if (dict_ !=
nullptr) {
784 TN_TOP2, charset, dict_ratio, cert_offset,
785 worst_dict_cert, step);
788 RecodeBeam* prev = secondary_beam_[t - 1];
791 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
793 ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
794 tprintf(
"Step %d: Dawg beam %d:\n", t, i);
795 DebugPath(charset, path);
798 for (
int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
800 ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
801 tprintf(
"Step %d: Non-Dawg beam %d:\n", t, i);
802 DebugPath(charset, path);
810 for (
int tn = 0; tn <
TN_COUNT && total_beam == 0; ++tn) {
811 TopNState top_n = static_cast<TopNState>(tn);
812 for (
int index = 0; index <
kNumBeams; ++index) {
816 for (
int i = prev->beams_[index].size() - 1; i >= 0; --i) {
817 ContinueContext(&prev->beams_[index].get(i).data, index, outputs,
818 top_n, charset, dict_ratio, cert_offset,
819 worst_dict_cert, step);
822 for (
int index = 0; index <
kNumBeams; ++index) {
824 total_beam += step->beams_[index].size();
829 for (
int c = 0; c <
NC_COUNT; ++c) {
830 if (step->best_initial_dawgs_[c].code >= 0) {
831 int index =
BeamIndex(
true, static_cast<NodeContinuation>(c), 0);
833 PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
844 void RecodeBeamSearch::ContinueContext(
const RecodeNode* prev,
int index,
845 const float* outputs,
850 double worst_dict_cert,
852 RecodedCharID prefix;
853 RecodedCharID full_code;
854 const RecodeNode* previous = prev;
858 for (
int p = length - 1; p >= 0; --p, previous = previous->prev) {
859 while (previous !=
nullptr &&
860 (previous->duplicate || previous->code == null_char_)) {
861 previous = previous->prev;
863 if (previous !=
nullptr) {
864 prefix.Set(p, previous->code);
865 full_code.Set(p, previous->code);
868 if (prev !=
nullptr && !is_simple_text_) {
869 if (top_n_flags_[prev->code] == top_n_flag) {
873 PushDupOrNoDawgIfBetter(length,
true, prev->code, prev->unichar_id,
874 cert, worst_dict_cert, dict_ratio, use_dawgs,
878 prev->code != null_char_) {
880 outputs[null_char_]) +
882 PushDupOrNoDawgIfBetter(length,
true, prev->code, prev->unichar_id,
883 cert, worst_dict_cert, dict_ratio, use_dawgs,
888 if (prev->code != null_char_ && length > 0 &&
889 top_n_flags_[null_char_] == top_n_flag) {
894 PushDupOrNoDawgIfBetter(length,
false, null_char_, INVALID_UNICHAR_ID,
895 cert, worst_dict_cert, dict_ratio, use_dawgs,
900 if (final_codes !=
nullptr) {
901 for (
int i = 0; i < final_codes->
size(); ++i) {
902 int code = (*final_codes)[i];
903 if (top_n_flags_[code] != top_n_flag)
continue;
904 if (prev !=
nullptr && prev->code == code && !is_simple_text_)
continue;
907 full_code.Set(length, code);
910 if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
911 if (unichar_id != INVALID_UNICHAR_ID &&
912 charset !=
nullptr &&
915 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
917 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
918 float prob = outputs[code] + outputs[null_char_];
920 prev->code != null_char_ &&
921 ((prev->code == top_code_ && code == second_code_) ||
922 (code == top_code_ && prev->code == second_code_))) {
923 prob += outputs[prev->code];
926 ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
932 if (next_codes !=
nullptr) {
933 for (
int i = 0; i < next_codes->
size(); ++i) {
934 int code = (*next_codes)[i];
935 if (top_n_flags_[code] != top_n_flag)
continue;
936 if (prev !=
nullptr && prev->code == code && !is_simple_text_)
continue;
938 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID, cert,
939 worst_dict_cert, dict_ratio, use_dawgs,
941 if (top_n_flag ==
TN_TOP2 && code != null_char_) {
942 float prob = outputs[code] + outputs[null_char_];
944 prev->code != null_char_ &&
945 ((prev->code == top_code_ && code == second_code_) ||
946 (code == top_code_ && prev->code == second_code_))) {
947 prob += outputs[prev->code];
950 PushDupOrNoDawgIfBetter(length + 1,
false, code, INVALID_UNICHAR_ID,
951 cert, worst_dict_cert, dict_ratio, use_dawgs,
959 void RecodeBeamSearch::ContinueUnichar(
int code,
int unichar_id,
float cert,
960 float worst_dict_cert,
float dict_ratio,
962 const RecodeNode* prev,
965 if (cert > worst_dict_cert) {
966 ContinueDawg(code, unichar_id, cert, cont, prev, step);
970 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
TOP_CHOICE_PERM,
false,
971 false,
false,
false, cert * dict_ratio, prev,
nullptr,
973 if (dict_ !=
nullptr &&
979 float dawg_cert = cert;
993 dawg_cert *= dict_ratio;
994 PushInitialDawgIfBetter(code, unichar_id, permuter,
false,
false,
995 dawg_cert, cont, prev, step);
1003 void RecodeBeamSearch::ContinueDawg(
int code,
int unichar_id,
float cert,
1005 const RecodeNode* prev, RecodeBeam* step) {
1008 if (unichar_id == INVALID_UNICHAR_ID) {
1009 PushHeapIfBetter(kBeamWidths[0], code, unichar_id,
NO_PERM,
false,
false,
1010 false,
false, cert, prev,
nullptr, dawg_heap);
1015 if (prev !=
nullptr) score += prev->score;
1016 if (dawg_heap->size() >= kBeamWidths[0] &&
1017 score <= dawg_heap->PeekTop().data.score &&
1018 nodawg_heap->size() >= kBeamWidths[0] &&
1019 score <= nodawg_heap->PeekTop().data.score) {
1022 const RecodeNode* uni_prev = prev;
1025 while (uni_prev !=
nullptr &&
1026 (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate))
1027 uni_prev = uni_prev->prev;
1029 if (uni_prev !=
nullptr && uni_prev->end_of_word) {
1032 PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter,
false,
1033 false, cert, cont, prev, step);
1034 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1035 false,
false,
false,
false, cert, prev,
nullptr,
1039 }
else if (uni_prev !=
nullptr && uni_prev->start_of_dawg &&
1045 DawgPositionVector initial_dawgs;
1046 auto* updated_dawgs =
new DawgPositionVector;
1047 DawgArgs dawg_args(&initial_dawgs, updated_dawgs,
NO_PERM);
1048 bool word_start =
false;
1049 if (uni_prev ==
nullptr) {
1053 }
else if (uni_prev->dawgs !=
nullptr) {
1055 dawg_args.active_dawgs = uni_prev->dawgs;
1056 word_start = uni_prev->start_of_dawg;
1060 auto permuter = static_cast<PermuterType>(
1064 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
1065 word_start, dawg_args.valid_end,
false, cert, prev,
1066 dawg_args.updated_dawgs, dawg_heap);
1067 if (dawg_args.valid_end && !space_delimited_) {
1071 PushInitialDawgIfBetter(code, unichar_id, permuter, word_start,
true,
1072 cert, cont, prev, step);
1073 PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter,
false,
1074 word_start,
true,
false, cert, prev,
nullptr,
1078 delete updated_dawgs;
1085 void RecodeBeamSearch::PushInitialDawgIfBetter(
int code,
int unichar_id,
1087 bool start,
bool end,
float cert,
1089 const RecodeNode* prev,
1091 RecodeNode* best_initial_dawg = &step->best_initial_dawgs_[cont];
1093 if (prev !=
nullptr) score += prev->score;
1094 if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1095 auto* initial_dawgs =
new DawgPositionVector;
1097 RecodeNode node(code, unichar_id, permuter,
true, start, end,
false, cert,
1098 score, prev, initial_dawgs,
1099 ComputeCodeHash(code,
false, prev));
1100 *best_initial_dawg = node;
1108 void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1109 int length,
bool dup,
int code,
int unichar_id,
float cert,
1110 float worst_dict_cert,
float dict_ratio,
bool use_dawgs,
1112 int index =
BeamIndex(use_dawgs, cont, length);
1114 if (cert > worst_dict_cert) {
1115 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1116 prev ? prev->permuter :
NO_PERM,
false,
false,
false,
1117 dup, cert, prev,
nullptr, &step->beams_[index]);
1122 PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1124 false, dup, cert, prev,
nullptr, &step->beams_[index]);
1132 void RecodeBeamSearch::PushHeapIfBetter(
int max_size,
int code,
int unichar_id,
1134 bool word_start,
bool end,
bool dup,
1135 float cert,
const RecodeNode* prev,
1136 DawgPositionVector* d,
1139 if (prev !=
nullptr) score += prev->score;
1140 if (heap->size() < max_size || score > heap->PeekTop().data.score) {
1141 uint64_t hash = ComputeCodeHash(code, dup, prev);
1142 RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1143 dup, cert, score, prev, d, hash);
1144 if (UpdateHeapIfMatched(&node, heap))
return;
1148 if (heap->size() > max_size) heap->Pop(&entry);
1156 void RecodeBeamSearch::PushHeapIfBetter(
int max_size, RecodeNode* node,
1158 if (heap->size() < max_size || node->score > heap->PeekTop().data.score) {
1159 if (UpdateHeapIfMatched(node, heap)) {
1165 if (heap->size() > max_size) heap->Pop(&entry);
1171 bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode* new_node,
1177 for (
int i = 0; i < nodes->
size(); ++i) {
1178 RecodeNode& node = (*nodes)[i].data;
1179 if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1180 node.permuter == new_node->permuter &&
1181 node.start_of_dawg == new_node->start_of_dawg) {
1182 if (new_node->score > node.score) {
1186 (*nodes)[i].key = node.score;
1187 heap->Reshuffle(&(*nodes)[i]);
1196 uint64_t RecodeBeamSearch::ComputeCodeHash(
int code,
bool dup,
1197 const RecodeNode* prev)
const {
1198 uint64_t hash = prev ==
nullptr ? 0 : prev->code_hash;
1199 if (!dup && code != null_char_) {
1201 uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1202 hash *= num_classes;
1213 void RecodeBeamSearch::ExtractBestPaths(
1217 const RecodeNode* best_node =
nullptr;
1218 const RecodeNode* second_best_node =
nullptr;
1219 const RecodeBeam* last_beam = beam_[beam_size_ - 1];
1220 for (
int c = 0; c <
NC_COUNT; ++c) {
1222 auto cont = static_cast<NodeContinuation>(c);
1223 for (
int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1224 int beam_index =
BeamIndex(is_dawg, cont, 0);
1225 int heap_size = last_beam->beams_[beam_index].size();
1226 for (
int h = 0; h < heap_size; ++h) {
1227 const RecodeNode* node = &last_beam->beams_[beam_index].get(h).data;
1231 const RecodeNode* dawg_node = node;
1232 while (dawg_node !=
nullptr &&
1233 (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1234 dawg_node->duplicate))
1235 dawg_node = dawg_node->prev;
1236 if (dawg_node ==
nullptr ||
1237 (!dawg_node->end_of_word &&
1243 if (best_node ==
nullptr || node->score > best_node->score) {
1244 second_best_node = best_node;
1246 }
else if (second_best_node ==
nullptr ||
1247 node->score > second_best_node->score) {
1248 second_best_node = node;
1253 if (second_nodes !=
nullptr) ExtractPath(second_best_node, second_nodes);
1254 ExtractPath(best_node, best_nodes);
1259 void RecodeBeamSearch::ExtractPath(
1262 while (node !=
nullptr) {
1269 void RecodeBeamSearch::ExtractPath(
1271 int limiter)
const {
1272 int pathcounter = 0;
1274 while (node !=
nullptr && pathcounter < limiter) {
1283 void RecodeBeamSearch::DebugPath(
1286 for (
int c = 0; c < path.
size(); ++c) {
1287 const RecodeNode& node = *path[c];
1289 node.Print(null_char_, *unicharset, 1);
1294 void RecodeBeamSearch::DebugUnicharPath(
1299 int num_ids = unichar_ids.
size();
1300 double total_rating = 0.0;
1301 for (
int c = 0; c < num_ids; ++c) {
1302 int coord = xcoords[c];
1303 tprintf(
"%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1305 certs[c], path[coord]->start_of_word, path[coord]->end_of_word,
1306 path[coord]->permuter);
1307 total_rating += ratings[c];
1309 tprintf(
"Path total rating = %g\n", total_rating);