tesseract  5.0.0-alpha-619-ge9db
recodebeam.cpp
Go to the documentation of this file.
1 // File: recodebeam.cpp
3 // Description: Beam search to decode from the re-encoded CJK as a sequence of
4 // smaller numbers in place of a single large code.
5 // Author: Ray Smith
6 //
7 // (C) Copyright 2015, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
17 //
19 
20 #include "recodebeam.h"
21 #include <deque>
22 #include <map>
23 #include <set>
24 #include <tuple>
25 #include <unordered_set>
26 #include <vector>
27 #include "networkio.h"
28 #include "pageres.h"
29 #include "unicharcompress.h"
30 
31 #include <algorithm>
32 
33 namespace tesseract {
34 
35 // The beam width at each code position.
36 const int RecodeBeamSearch::kBeamWidths[RecodedCharID::kMaxCodeLen + 1] = {
37  5, 10, 16, 16, 16, 16, 16, 16, 16, 16,
38 };
39 
// Printable names used by DebugBeams, which indexes this table with the
// integer value of a NodeContinuation. The pointers are themselves const so
// the table cannot be re-pointed at runtime.
static const char* const kNodeContNames[] = {"Anything", "OnlyDup", "NoDup"};
41 
42 // Prints debug details of the node.
43 void RecodeNode::Print(int null_char, const UNICHARSET& unicharset,
44  int depth) const {
45  if (code == null_char) {
46  tprintf("null_char");
47  } else {
48  tprintf("label=%d, uid=%d=%s", code, unichar_id,
49  unicharset.debug_str(unichar_id).c_str());
50  }
51  tprintf(" score=%g, c=%g,%s%s%s perm=%d, hash=%" PRIx64, score, certainty,
52  start_of_dawg ? " DawgStart" : "", start_of_word ? " Start" : "",
53  end_of_word ? " End" : "", permuter, code_hash);
54  if (depth > 0 && prev != nullptr) {
55  tprintf(" prev:");
56  prev->Print(null_char, unicharset, depth - 1);
57  } else {
58  tprintf("\n");
59  }
60 }
61 
62 // Borrows the pointer, which is expected to survive until *this is deleted.
64  int null_char, bool simple_text, Dict* dict)
65  : recoder_(recoder),
66  beam_size_(0),
67  top_code_(-1),
68  second_code_(-1),
69  dict_(dict),
70  space_delimited_(true),
71  is_simple_text_(simple_text),
72  null_char_(null_char) {
73  if (dict_ != nullptr && !dict_->IsSpaceDelimitedLang())
74  space_delimited_ = false;
75 }
76 
77 // Decodes the set of network outputs, storing the lattice internally.
78 void RecodeBeamSearch::Decode(const NetworkIO& output, double dict_ratio,
79  double cert_offset, double worst_dict_cert,
80  const UNICHARSET* charset, int lstm_choice_mode) {
81  beam_size_ = 0;
82  int width = output.Width();
83  if (lstm_choice_mode) timesteps.clear();
84  for (int t = 0; t < width; ++t) {
85  ComputeTopN(output.f(t), output.NumFeatures(), kBeamWidths[0]);
86  DecodeStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
87  charset);
88  if (lstm_choice_mode) {
89  SaveMostCertainChoices(output.f(t), output.NumFeatures(), charset, t);
90  }
91  }
92 }
94  double dict_ratio, double cert_offset,
95  double worst_dict_cert,
96  const UNICHARSET* charset) {
97  beam_size_ = 0;
98  int width = output.dim1();
99  for (int t = 0; t < width; ++t) {
100  ComputeTopN(output[t], output.dim2(), kBeamWidths[0]);
101  DecodeStep(output[t], t, dict_ratio, cert_offset, worst_dict_cert, charset);
102  }
103 }
104 
106  double dict_ratio,
107  double cert_offset,
108  double worst_dict_cert,
109  const UNICHARSET* charset,
110  int lstm_choice_mode) {
111  secondary_beam_.clear();
112  if (character_boundaries_.size() < 2) return;
113  int width = output.Width();
114  int bucketNumber = 0;
115  for (int t = 0; t < width; ++t) {
116  while ((bucketNumber + 1) < character_boundaries_.size() &&
117  t >= character_boundaries_[bucketNumber + 1])
118  {
119  ++bucketNumber;
120  }
121  ComputeSecTopN(&(excludedUnichars)[bucketNumber], output.f(t),
122  output.NumFeatures(), kBeamWidths[0]);
123  DecodeSecondaryStep(output.f(t), t, dict_ratio, cert_offset, worst_dict_cert,
124  charset);
125  }
126 }
127 
128 void RecodeBeamSearch::SaveMostCertainChoices(const float* outputs,
129  int num_outputs,
130  const UNICHARSET* charset,
131  int xCoord) {
132  std::vector<std::pair<const char*, float>> choices;
133  for (int i = 0; i < num_outputs; ++i) {
134  if (outputs[i] >= 0.01f) {
135  const char* character;
136  if (i + 2 >= num_outputs) {
137  character = "";
138  } else if (i > 0) {
139  character = charset->id_to_unichar_ext(i + 2);
140  } else {
141  character = charset->id_to_unichar_ext(i);
142  }
143  size_t pos = 0;
144  // order the possible choices within one timestep
145  // beginning with the most likely
146  while (choices.size() > pos && choices[pos].second > outputs[i]) {
147  pos++;
148  }
149  choices.insert(choices.begin() + pos,
150  std::pair<const char*, float>(character, outputs[i]));
151  }
152  }
153  timesteps.push_back(choices);
154 }
155 
157  for (int i = 1; i < character_boundaries_.size(); ++i){
158  std::vector<std::vector<std::pair<const char*, float>>> segment;
159  for (int j = character_boundaries_[i - 1]; j < character_boundaries_[i]; ++j) {
160  segment.push_back(timesteps[j]);
161  }
162  segmentedTimesteps.push_back(segment);
163  }
164 }
165 std::vector<std::vector<std::pair<const char*, float>>>
167  std::vector<std::vector<std::vector<std::pair<const char*, float>>>>*
168  segmentedTimesteps) {
169  std::vector<std::vector<std::pair<const char*, float>>> combined_timesteps;
170  for (int i = 0; i < segmentedTimesteps->size(); ++i){
171  for (int j = 0; j < (*segmentedTimesteps)[i].size(); ++j) {
172  combined_timesteps.push_back((*segmentedTimesteps)[i][j]);
173  }
174  }
175  return combined_timesteps;
176 }
177 
178 void RecodeBeamSearch::calculateCharBoundaries(std::vector<int>* starts,
179  std::vector<int>* ends,
180  std::vector<int>* char_bounds_,
181  int maxWidth) {
182  char_bounds_->push_back(0);
183  for (int i = 0; i < ends->size(); ++i) {
184  int middle = ((*starts)[i+1]-(*ends)[i])/2;
185  char_bounds_->push_back((*ends)[i] + middle);
186  }
187  char_bounds_->pop_back();
188  char_bounds_->push_back(maxWidth);
189 }
190 
191 // Returns the best path as labels/scores/xcoords similar to simple CTC.
193  GenericVector<int>* labels, GenericVector<int>* xcoords) const {
194  labels->truncate(0);
195  xcoords->truncate(0);
197  ExtractBestPaths(&best_nodes, nullptr);
198  // Now just run CTC on the best nodes.
199  int t = 0;
200  int width = best_nodes.size();
201  while (t < width) {
202  int label = best_nodes[t]->code;
203  if (label != null_char_) {
204  labels->push_back(label);
205  xcoords->push_back(t);
206  }
207  while (++t < width && !is_simple_text_ && best_nodes[t]->code == label) {
208  }
209  }
210  xcoords->push_back(width);
211 }
212 
213 // Returns the best path as unichar-ids/certs/ratings/xcoords skipping
214 // duplicates, nulls and intermediate parts.
216  bool debug, const UNICHARSET* unicharset, GenericVector<int>* unichar_ids,
218  GenericVector<int>* xcoords) const {
220  ExtractBestPaths(&best_nodes, nullptr);
221  ExtractPathAsUnicharIds(best_nodes, unichar_ids, certs, ratings, xcoords);
222  if (debug) {
223  DebugPath(unicharset, best_nodes);
224  DebugUnicharPath(unicharset, best_nodes, *unichar_ids, *certs, *ratings,
225  *xcoords);
226  }
227 }
228 
229 // Returns the best path as a set of WERD_RES.
231  float scale_factor, bool debug,
232  const UNICHARSET* unicharset,
234  int lstm_choice_mode) {
235  words->truncate(0);
236  GenericVector<int> unichar_ids;
237  GenericVector<float> certs;
238  GenericVector<float> ratings;
239  GenericVector<int> xcoords;
242  character_boundaries_.clear();
243  ExtractBestPaths(&best_nodes, &second_nodes);
244  if (debug) {
245  DebugPath(unicharset, best_nodes);
246  ExtractPathAsUnicharIds(second_nodes, &unichar_ids, &certs, &ratings,
247  &xcoords);
248  tprintf("\nSecond choice path:\n");
249  DebugUnicharPath(unicharset, second_nodes, unichar_ids, certs, ratings,
250  xcoords);
251  }
252  // If lstm choice mode is required in granularity level 2, it stores the x
253  // Coordinates of every chosen character, to match the alternative choices to
254  // it.
255  ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings, &xcoords,
257  int num_ids = unichar_ids.size();
258  if (debug) {
259  DebugUnicharPath(unicharset, best_nodes, unichar_ids, certs, ratings,
260  xcoords);
261  }
262  // Convert labels to unichar-ids.
263  int word_end = 0;
264  float prev_space_cert = 0.0f;
265  for (int word_start = 0; word_start < num_ids; word_start = word_end) {
266  for (word_end = word_start + 1; word_end < num_ids; ++word_end) {
267  // A word is terminated when a space character or start_of_word flag is
268  // hit. We also want to force a separate word for every non
269  // space-delimited character when not in a dictionary context.
270  if (unichar_ids[word_end] == UNICHAR_SPACE) break;
271  int index = xcoords[word_end];
272  if (best_nodes[index]->start_of_word) break;
273  if (best_nodes[index]->permuter == TOP_CHOICE_PERM &&
274  (!unicharset->IsSpaceDelimited(unichar_ids[word_end]) ||
275  !unicharset->IsSpaceDelimited(unichar_ids[word_end - 1])))
276  break;
277  }
278  float space_cert = 0.0f;
279  if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE)
280  space_cert = certs[word_end];
281  bool leading_space =
282  word_start > 0 && unichar_ids[word_start - 1] == UNICHAR_SPACE;
283  // Create a WERD_RES for the output word.
284  WERD_RES* word_res =
285  InitializeWord(leading_space, line_box, word_start, word_end,
286  std::min(space_cert, prev_space_cert), unicharset,
287  xcoords, scale_factor);
288  for (int i = word_start; i < word_end; ++i) {
289  auto* choices = new BLOB_CHOICE_LIST;
290  BLOB_CHOICE_IT bc_it(choices);
291  auto* choice = new BLOB_CHOICE(
292  unichar_ids[i], ratings[i], certs[i], -1, 1.0f,
293  static_cast<float>(INT16_MAX), 0.0f, BCC_STATIC_CLASSIFIER);
294  int col = i - word_start;
295  choice->set_matrix_cell(col, col);
296  bc_it.add_after_then_move(choice);
297  word_res->ratings->put(col, col, choices);
298  }
299  int index = xcoords[word_end - 1];
300  word_res->FakeWordFromRatings(best_nodes[index]->permuter);
301  words->push_back(word_res);
302  prev_space_cert = space_cert;
303  if (word_end < num_ids && unichar_ids[word_end] == UNICHAR_SPACE)
304  ++word_end;
305  }
306 }
307 
308 struct greater_than {
309  inline bool operator()(const RecodeNode*& node1,
310  const RecodeNode*& node2) {
311  return (node1->score > node2->score);
312  }
313 };
314 
315 void RecodeBeamSearch::PrintBeam2(bool uids, int num_outputs,
316  const UNICHARSET* charset,
317  bool secondary) const {
318  std::vector<std::vector<const RecodeNode*>> topology;
319  std::unordered_set<const RecodeNode*> visited;
320  const PointerVector<RecodeBeam>* beam = !secondary ? &beam_ : &secondary_beam_;
321  // create the topology
322  for (int step = beam->size()-1; step >=0; --step) {
323  std::vector<const RecodeNode*> layer;
324  topology.push_back(layer);
325  }
326  // fill the topology with depths first
327  for (int step = beam->size() - 1; step >= 0; --step) {
329  beam->get(step)->beams_->heap();
330  for (int node = 0; node < heaps->size(); ++node) {
331  int backtracker = 0;
332  const RecodeNode* curr = &heaps->get(node).data;
333  while (curr != nullptr && !visited.count(curr)) {
334  visited.insert(curr);
335  topology[step - backtracker].push_back(curr);
336  curr = curr->prev;
337  ++backtracker;
338  }
339  }
340  }
341  int ct = 0;
342  int cb = 1;
343  for (std::vector<const RecodeNode*> layer: topology) {
344  if (cb >= character_boundaries_.size())
345  break;
346  if (ct == character_boundaries_[cb]) {
347  tprintf("***\n");
348  ++cb;
349  }
350  for (const RecodeNode* node : layer) {
351  const char* code;
352  int intCode;
353  if (node->unichar_id != INVALID_UNICHAR_ID) {
354  code = charset->id_to_unichar(node->unichar_id);
355  intCode = node->unichar_id;
356  } else if(node->code == null_char_) {
357  intCode = 0;
358  code = " ";
359  } else {
360  intCode = 666;
361  code = "*";
362  }
363  int intPrevCode = 0;
364  const char* prevCode;
365  float prevScore = 0;
366  if (node->prev != nullptr) {
367  prevScore = node->prev->score;
368  if (node->prev->unichar_id != INVALID_UNICHAR_ID) {
369  prevCode = charset->id_to_unichar(node->prev->unichar_id);
370  intPrevCode = node->prev->unichar_id;
371  } else if (node->code == null_char_) {
372  intPrevCode = 0;
373  prevCode = " ";
374  } else {
375  prevCode = "*";
376  intPrevCode = 666;
377  }
378  } else {
379  prevCode = " ";
380  }
381  if (uids) {
382  tprintf("%x(|)%f(>)%x(|)%f\n", intPrevCode,
383  prevScore, intCode, node->score);
384  } else {
385  tprintf("%s(|)%f(>)%s(|)%f\n", prevCode,
386  prevScore, code, node->score);
387  }
388  }
389  tprintf("-\n");
390  ++ct;
391  }
392  tprintf("***\n");
393 }
394 
396  GenericVector<tesseract::RecodePair>* heaps = nullptr;
397  PointerVector<RecodeBeam>* currentBeam = nullptr;
398  if (character_boundaries_.size() < 2) return;
399  // For the first iteration the original beam is analyzed. After that a
400  // new beam is calculated based on the results from the original beam.
401  if (secondary_beam_.empty()) {
402  currentBeam = &beam_;
403  } else {
404  currentBeam = &secondary_beam_;
405  }
406  character_boundaries_[0] = 0;
407  for (int j = 1; j < character_boundaries_.size(); ++j) {
408  GenericVector<int> unichar_ids;
409  GenericVector<float> certs;
410  GenericVector<float> ratings;
411  GenericVector<int> xcoords;
412  int backpath = character_boundaries_[j] - character_boundaries_[j - 1];
413  heaps = currentBeam->get(character_boundaries_[j] - 1)->beams_->heap();
415  std::vector<const RecodeNode*> best;
416  // Scan the segmented node chain for valid unichar ids.
417  for (int i = 0; i < heaps->size(); ++i) {
418  bool validChar = false;
419  int backcounter = 0;
420  const RecodeNode* node = &heaps->get(i).data;
421  while (node != nullptr && backcounter < backpath) {
422  if (node->code != null_char_ && node->unichar_id != INVALID_UNICHAR_ID) {
423  validChar = true;
424  break;
425  }
426  node = node->prev;
427  ++backcounter;
428  }
429  if (validChar) best.push_back(&heaps->get(i).data);
430  }
431  // find the best rated segmented node chain and extract the unichar id.
432  if (!best.empty()) {
433  std::sort(best.begin(), best.end(), greater_than());
434  ExtractPath(best[0], &best_nodes, backpath);
435  ExtractPathAsUnicharIds(best_nodes, &unichar_ids, &certs, &ratings,
436  &xcoords);
437  }
438  if (!unichar_ids.empty()) {
439  int bestPos = 0;
440  for (int i = 1; i < unichar_ids.size(); ++i) {
441  if (ratings[i] < ratings[bestPos])
442  bestPos = i;
443  }
444  int bestCode = -10;
445  for (int i = 0; i < best_nodes.size(); ++i) {
446  if (best_nodes[i]->unichar_id == unichar_ids[bestPos]) {
447  bestCode = best_nodes[i]->code;
448  }
449  }
450  // Exclude the best choice for the followup decoding.
451  std::unordered_set<int> excludeCodeList;
452  for (int node = 0; node < best_nodes.size(); ++node) {
453  if (best_nodes[node]->code != null_char_) {
454  excludeCodeList.insert(best_nodes[node]->code);
455  }
456  }
457  if (j - 1 < excludedUnichars.size()) {
458  for (auto elem : excludeCodeList) {
459  excludedUnichars[j - 1].insert(elem);
460  }
461  } else {
462  excludedUnichars.push_back(excludeCodeList);
463  }
464  // Save the best choice for the choice iterator.
465  if (j - 1 < ctc_choices.size()) {
466  int id = unichar_ids[bestPos];
467  const char* result = unicharset->id_to_unichar_ext(id);
468  float rating = ratings[bestPos];
469  ctc_choices[j - 1].push_back(
470  std::pair<const char*, float>(result, rating));
471  } else {
472  std::vector<std::pair<const char*, float>> choice;
473  int id = unichar_ids[bestPos];
474  const char* result = unicharset->id_to_unichar_ext(id);
475  float rating = ratings[bestPos];
476  choice.push_back(std::pair<const char*, float>(result, rating));
477  ctc_choices.push_back(choice);
478  }
479  // fill the blank spot with an empty array
480  } else {
481  if (j - 1 >= excludedUnichars.size()) {
482  std::unordered_set<int> excludeCodeList;
483  excludedUnichars.push_back(excludeCodeList);
484  }
485  if (j - 1 >= ctc_choices.size()) {
486  std::vector<std::pair<const char*, float>> choice;
487  ctc_choices.push_back(choice);
488  }
489  }
490  }
491  secondary_beam_.clear();
492 }
493 
494 // Generates debug output of the content of the beams after a Decode.
495 void RecodeBeamSearch::DebugBeams(const UNICHARSET& unicharset) const {
496  for (int p = 0; p < beam_size_; ++p) {
497  for (int d = 0; d < 2; ++d) {
498  for (int c = 0; c < NC_COUNT; ++c) {
499  auto cont = static_cast<NodeContinuation>(c);
500  int index = BeamIndex(d, cont, 0);
501  if (beam_[p]->beams_[index].empty()) continue;
502  // Print all the best scoring nodes for each unichar found.
503  tprintf("Position %d: %s+%s beam\n", p, d ? "Dict" : "Non-Dict",
504  kNodeContNames[c]);
505  DebugBeamPos(unicharset, beam_[p]->beams_[index]);
506  }
507  }
508  }
509 }
510 
511 // Generates debug output of the content of a single beam position.
512 void RecodeBeamSearch::DebugBeamPos(const UNICHARSET& unicharset,
513  const RecodeHeap& heap) const {
514  GenericVector<const RecodeNode*> unichar_bests;
515  unichar_bests.init_to_size(unicharset.size(), nullptr);
516  const RecodeNode* null_best = nullptr;
517  int heap_size = heap.size();
518  for (int i = 0; i < heap_size; ++i) {
519  const RecodeNode* node = &heap.get(i).data;
520  if (node->unichar_id == INVALID_UNICHAR_ID) {
521  if (null_best == nullptr || null_best->score < node->score)
522  null_best = node;
523  } else {
524  if (unichar_bests[node->unichar_id] == nullptr ||
525  unichar_bests[node->unichar_id]->score < node->score) {
526  unichar_bests[node->unichar_id] = node;
527  }
528  }
529  }
530  for (int u = 0; u < unichar_bests.size(); ++u) {
531  if (unichar_bests[u] != nullptr) {
532  const RecodeNode& node = *unichar_bests[u];
533  node.Print(null_char_, unicharset, 1);
534  }
535  }
536  if (null_best != nullptr) {
537  null_best->Print(null_char_, unicharset, 1);
538  }
539 }
540 
541 // Returns the given best_nodes as unichar-ids/certs/ratings/xcoords skipping
542 // duplicates, nulls and intermediate parts.
543 /* static */
544 void RecodeBeamSearch::ExtractPathAsUnicharIds(
545  const GenericVector<const RecodeNode*>& best_nodes,
546  GenericVector<int>* unichar_ids, GenericVector<float>* certs,
547  GenericVector<float>* ratings, GenericVector<int>* xcoords,
548  std::vector<int>* character_boundaries) {
549  unichar_ids->truncate(0);
550  certs->truncate(0);
551  ratings->truncate(0);
552  xcoords->truncate(0);
553  std::vector<int> starts;
554  std::vector<int> ends;
555  // Backtrack extracting only valid, non-duplicate unichar-ids.
556  int t = 0;
557  int width = best_nodes.size();
558  while (t < width) {
559  double certainty = 0.0;
560  double rating = 0.0;
561  while (t < width && best_nodes[t]->unichar_id == INVALID_UNICHAR_ID) {
562  double cert = best_nodes[t++]->certainty;
563  if (cert < certainty) certainty = cert;
564  rating -= cert;
565  }
566  starts.push_back(t);
567  if (t < width) {
568  int unichar_id = best_nodes[t]->unichar_id;
569  if (unichar_id == UNICHAR_SPACE && !certs->empty() &&
570  best_nodes[t]->permuter != NO_PERM) {
571  // All the rating and certainty go on the previous character except
572  // for the space itself.
573  if (certainty < certs->back()) certs->back() = certainty;
574  ratings->back() += rating;
575  certainty = 0.0;
576  rating = 0.0;
577  }
578  unichar_ids->push_back(unichar_id);
579  xcoords->push_back(t);
580  do {
581  double cert = best_nodes[t++]->certainty;
582  // Special-case NO-PERM space to forget the certainty of the previous
583  // nulls. See long comment in ContinueContext.
584  if (cert < certainty || (unichar_id == UNICHAR_SPACE &&
585  best_nodes[t - 1]->permuter == NO_PERM)) {
586  certainty = cert;
587  }
588  rating -= cert;
589  } while (t < width && best_nodes[t]->duplicate);
590  ends.push_back(t);
591  certs->push_back(certainty);
592  ratings->push_back(rating);
593  } else if (!certs->empty()) {
594  if (certainty < certs->back()) certs->back() = certainty;
595  ratings->back() += rating;
596  }
597  }
598  starts.push_back(width);
599  if (character_boundaries != nullptr) {
600  calculateCharBoundaries(&starts, &ends, character_boundaries, width);
601  }
602  xcoords->push_back(width);
603 }
604 
605 // Sets up a word with the ratings matrix and fake blobs with boxes in the
606 // right places.
607 WERD_RES* RecodeBeamSearch::InitializeWord(bool leading_space,
608  const TBOX& line_box, int word_start,
609  int word_end, float space_certainty,
610  const UNICHARSET* unicharset,
611  const GenericVector<int>& xcoords,
612  float scale_factor) {
613  // Make a fake blob for each non-zero label.
614  C_BLOB_LIST blobs;
615  C_BLOB_IT b_it(&blobs);
616  for (int i = word_start; i < word_end; ++i) {
617  if (character_boundaries_.size() > (i + 1)) {
619  line_box.height());
620  box.scale(scale_factor);
621  box.move(ICOORD(line_box.left(), line_box.bottom()));
622  box.set_top(line_box.top());
623  b_it.add_after_then_move(C_BLOB::FakeBlob(box));
624  }
625  }
626  // Make a fake word from the blobs.
627  WERD* word = new WERD(&blobs, leading_space, nullptr);
628  // Make a WERD_RES from the word.
629  auto* word_res = new WERD_RES(word);
630  word_res->end = word_end - word_start + leading_space;
631  word_res->uch_set = unicharset;
632  word_res->combination = true; // Give it ownership of the word.
633  word_res->space_certainty = space_certainty;
634  word_res->ratings = new MATRIX(word_end - word_start, 1);
635  return word_res;
636 }
637 
638 // Fills top_n_flags_ with bools that are true iff the corresponding output
639 // is one of the top_n.
640 void RecodeBeamSearch::ComputeTopN(const float* outputs, int num_outputs,
641  int top_n) {
642  top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN);
643  top_code_ = -1;
644  second_code_ = -1;
645  top_heap_.clear();
646  for (int i = 0; i < num_outputs; ++i) {
647  if (top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key) {
648  TopPair entry(outputs[i], i);
649  top_heap_.Push(&entry);
650  if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
651  }
652  }
653  while (!top_heap_.empty()) {
654  TopPair entry;
655  top_heap_.Pop(&entry);
656  if (top_heap_.size() > 1) {
657  top_n_flags_[entry.data] = TN_TOPN;
658  } else {
659  top_n_flags_[entry.data] = TN_TOP2;
660  if (top_heap_.empty())
661  top_code_ = entry.data;
662  else
663  second_code_ = entry.data;
664  }
665  }
666  top_n_flags_[null_char_] = TN_TOP2;
667 }
668 
669 void RecodeBeamSearch::ComputeSecTopN(std::unordered_set<int>* exList,
670  const float* outputs, int num_outputs,
671  int top_n) {
672  top_n_flags_.init_to_size(num_outputs, TN_ALSO_RAN);
673  top_code_ = -1;
674  second_code_ = -1;
675  top_heap_.clear();
676  for (int i = 0; i < num_outputs; ++i) {
677  if ((top_heap_.size() < top_n || outputs[i] > top_heap_.PeekTop().key)
678  && !exList->count(i)) {
679  TopPair entry(outputs[i], i);
680  top_heap_.Push(&entry);
681  if (top_heap_.size() > top_n) top_heap_.Pop(&entry);
682  }
683  }
684  while (!top_heap_.empty()) {
685  TopPair entry;
686  top_heap_.Pop(&entry);
687  if (top_heap_.size() > 1) {
688  top_n_flags_[entry.data] = TN_TOPN;
689  } else {
690  top_n_flags_[entry.data] = TN_TOP2;
691  if (top_heap_.empty())
692  top_code_ = entry.data;
693  else
694  second_code_ = entry.data;
695  }
696  }
697  top_n_flags_[null_char_] = TN_TOP2;
698 }
699 
700 // Adds the computation for the current time-step to the beam. Call at each
701 // time-step in sequence from left to right. outputs is the activation vector
702 // for the current timestep.
703 void RecodeBeamSearch::DecodeStep(const float* outputs, int t,
704  double dict_ratio, double cert_offset,
705  double worst_dict_cert,
706  const UNICHARSET* charset, bool debug) {
707  if (t == beam_.size()) beam_.push_back(new RecodeBeam);
708  RecodeBeam* step = beam_[t];
709  beam_size_ = t + 1;
710  step->Clear();
711  if (t == 0) {
712  // The first step can only use singles and initials.
713  ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
714  charset, dict_ratio, cert_offset, worst_dict_cert, step);
715  if (dict_ != nullptr) {
716  ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs, TN_TOP2,
717  charset, dict_ratio, cert_offset, worst_dict_cert, step);
718  }
719  } else {
720  RecodeBeam* prev = beam_[t - 1];
721  if (debug) {
722  int beam_index = BeamIndex(true, NC_ANYTHING, 0);
723  for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
725  ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
726  tprintf("Step %d: Dawg beam %d:\n", t, i);
727  DebugPath(charset, path);
728  }
729  beam_index = BeamIndex(false, NC_ANYTHING, 0);
730  for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
732  ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
733  tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
734  DebugPath(charset, path);
735  }
736  }
737  int total_beam = 0;
738  // Work through the scores by group (top-2, top-n, the rest) while the beam
739  // is empty. This enables extending the context using only the top-n results
740  // first, which may have an empty intersection with the valid codes, so we
741  // fall back to the rest if the beam is empty.
742  for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
743  auto top_n = static_cast<TopNState>(tn);
744  for (int index = 0; index < kNumBeams; ++index) {
745  // Working backwards through the heaps doesn't guarantee that we see the
746  // best first, but it comes before a lot of the worst, so it is slightly
747  // more efficient than going forwards.
748  for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
749  ContinueContext(&prev->beams_[index].get(i).data, index, outputs, top_n,
750  charset, dict_ratio, cert_offset, worst_dict_cert, step);
751  }
752  }
753  for (int index = 0; index < kNumBeams; ++index) {
755  total_beam += step->beams_[index].size();
756  }
757  }
758  // Special case for the best initial dawg. Push it on the heap if good
759  // enough, but there is only one, so it doesn't blow up the beam.
760  for (int c = 0; c < NC_COUNT; ++c) {
761  if (step->best_initial_dawgs_[c].code >= 0) {
762  int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
763  RecodeHeap* dawg_heap = &step->beams_[index];
764  PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
765  dawg_heap);
766  }
767  }
768  }
769 }
770 
771 void RecodeBeamSearch::DecodeSecondaryStep(const float* outputs, int t,
772  double dict_ratio, double cert_offset,
773  double worst_dict_cert,
774  const UNICHARSET* charset, bool debug) {
775  if (t == secondary_beam_.size()) secondary_beam_.push_back(new RecodeBeam);
776  RecodeBeam* step = secondary_beam_[t];
777  step->Clear();
778  if (t == 0) {
779  // The first step can only use singles and initials.
780  ContinueContext(nullptr, BeamIndex(false, NC_ANYTHING, 0), outputs, TN_TOP2,
781  charset, dict_ratio, cert_offset, worst_dict_cert, step);
782  if (dict_ != nullptr) {
783  ContinueContext(nullptr, BeamIndex(true, NC_ANYTHING, 0), outputs,
784  TN_TOP2, charset, dict_ratio, cert_offset,
785  worst_dict_cert, step);
786  }
787  } else {
788  RecodeBeam* prev = secondary_beam_[t - 1];
789  if (debug) {
790  int beam_index = BeamIndex(true, NC_ANYTHING, 0);
791  for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
793  ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
794  tprintf("Step %d: Dawg beam %d:\n", t, i);
795  DebugPath(charset, path);
796  }
797  beam_index = BeamIndex(false, NC_ANYTHING, 0);
798  for (int i = prev->beams_[beam_index].size() - 1; i >= 0; --i) {
800  ExtractPath(&prev->beams_[beam_index].get(i).data, &path);
801  tprintf("Step %d: Non-Dawg beam %d:\n", t, i);
802  DebugPath(charset, path);
803  }
804  }
805  int total_beam = 0;
806  // Work through the scores by group (top-2, top-n, the rest) while the beam
807  // is empty. This enables extending the context using only the top-n results
808  // first, which may have an empty intersection with the valid codes, so we
809  // fall back to the rest if the beam is empty.
810  for (int tn = 0; tn < TN_COUNT && total_beam == 0; ++tn) {
811  TopNState top_n = static_cast<TopNState>(tn);
812  for (int index = 0; index < kNumBeams; ++index) {
813  // Working backwards through the heaps doesn't guarantee that we see the
814  // best first, but it comes before a lot of the worst, so it is slightly
815  // more efficient than going forwards.
816  for (int i = prev->beams_[index].size() - 1; i >= 0; --i) {
817  ContinueContext(&prev->beams_[index].get(i).data, index, outputs,
818  top_n, charset, dict_ratio, cert_offset,
819  worst_dict_cert, step);
820  }
821  }
822  for (int index = 0; index < kNumBeams; ++index) {
824  total_beam += step->beams_[index].size();
825  }
826  }
827  // Special case for the best initial dawg. Push it on the heap if good
828  // enough, but there is only one, so it doesn't blow up the beam.
829  for (int c = 0; c < NC_COUNT; ++c) {
830  if (step->best_initial_dawgs_[c].code >= 0) {
831  int index = BeamIndex(true, static_cast<NodeContinuation>(c), 0);
832  RecodeHeap* dawg_heap = &step->beams_[index];
833  PushHeapIfBetter(kBeamWidths[0], &step->best_initial_dawgs_[c],
834  dawg_heap);
835  }
836  }
837  }
838 }
839 
840 // Adds to the appropriate beams the legal (according to recoder)
841 // continuations of context prev, which is of the given length, using the
842 // given network outputs to provide scores to the choices. Uses only those
843 // choices for which top_n_flags[index] == top_n_flag.
844 void RecodeBeamSearch::ContinueContext(const RecodeNode* prev, int index,
845  const float* outputs,
846  TopNState top_n_flag,
847  const UNICHARSET* charset,
848  double dict_ratio,
849  double cert_offset,
850  double worst_dict_cert,
851  RecodeBeam* step) {
852  RecodedCharID prefix;
853  RecodedCharID full_code;
854  const RecodeNode* previous = prev;
855  int length = LengthFromBeamsIndex(index);
856  bool use_dawgs = IsDawgFromBeamsIndex(index);
857  NodeContinuation prev_cont = ContinuationFromBeamsIndex(index);
858  for (int p = length - 1; p >= 0; --p, previous = previous->prev) {
859  while (previous != nullptr &&
860  (previous->duplicate || previous->code == null_char_)) {
861  previous = previous->prev;
862  }
863  if (previous != nullptr) {
864  prefix.Set(p, previous->code);
865  full_code.Set(p, previous->code);
866  }
867  }
868  if (prev != nullptr && !is_simple_text_) {
869  if (top_n_flags_[prev->code] == top_n_flag) {
870  if (prev_cont != NC_NO_DUP) {
871  float cert =
872  NetworkIO::ProbToCertainty(outputs[prev->code]) + cert_offset;
873  PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
874  cert, worst_dict_cert, dict_ratio, use_dawgs,
875  NC_ANYTHING, prev, step);
876  }
877  if (prev_cont == NC_ANYTHING && top_n_flag == TN_TOP2 &&
878  prev->code != null_char_) {
879  float cert = NetworkIO::ProbToCertainty(outputs[prev->code] +
880  outputs[null_char_]) +
881  cert_offset;
882  PushDupOrNoDawgIfBetter(length, true, prev->code, prev->unichar_id,
883  cert, worst_dict_cert, dict_ratio, use_dawgs,
884  NC_NO_DUP, prev, step);
885  }
886  }
887  if (prev_cont == NC_ONLY_DUP) return;
888  if (prev->code != null_char_ && length > 0 &&
889  top_n_flags_[null_char_] == top_n_flag) {
890  // Allow nulls within multi code sequences, as the nulls within are not
891  // explicitly included in the code sequence.
892  float cert =
893  NetworkIO::ProbToCertainty(outputs[null_char_]) + cert_offset;
894  PushDupOrNoDawgIfBetter(length, false, null_char_, INVALID_UNICHAR_ID,
895  cert, worst_dict_cert, dict_ratio, use_dawgs,
896  NC_ANYTHING, prev, step);
897  }
898  }
899  const GenericVector<int>* final_codes = recoder_.GetFinalCodes(prefix);
900  if (final_codes != nullptr) {
901  for (int i = 0; i < final_codes->size(); ++i) {
902  int code = (*final_codes)[i];
903  if (top_n_flags_[code] != top_n_flag) continue;
904  if (prev != nullptr && prev->code == code && !is_simple_text_) continue;
905  float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
906  if (cert < kMinCertainty && code != null_char_) continue;
907  full_code.Set(length, code);
908  int unichar_id = recoder_.DecodeUnichar(full_code);
909  // Map the null char to INVALID.
910  if (length == 0 && code == null_char_) unichar_id = INVALID_UNICHAR_ID;
911  if (unichar_id != INVALID_UNICHAR_ID &&
912  charset != nullptr &&
913  !charset->get_enabled(unichar_id))
914  continue; // disabled by whitelist/blacklist
915  ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
916  use_dawgs, NC_ANYTHING, prev, step);
917  if (top_n_flag == TN_TOP2 && code != null_char_) {
918  float prob = outputs[code] + outputs[null_char_];
919  if (prev != nullptr && prev_cont == NC_ANYTHING &&
920  prev->code != null_char_ &&
921  ((prev->code == top_code_ && code == second_code_) ||
922  (code == top_code_ && prev->code == second_code_))) {
923  prob += outputs[prev->code];
924  }
925  float cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
926  ContinueUnichar(code, unichar_id, cert, worst_dict_cert, dict_ratio,
927  use_dawgs, NC_ONLY_DUP, prev, step);
928  }
929  }
930  }
931  const GenericVector<int>* next_codes = recoder_.GetNextCodes(prefix);
932  if (next_codes != nullptr) {
933  for (int i = 0; i < next_codes->size(); ++i) {
934  int code = (*next_codes)[i];
935  if (top_n_flags_[code] != top_n_flag) continue;
936  if (prev != nullptr && prev->code == code && !is_simple_text_) continue;
937  float cert = NetworkIO::ProbToCertainty(outputs[code]) + cert_offset;
938  PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID, cert,
939  worst_dict_cert, dict_ratio, use_dawgs,
940  NC_ANYTHING, prev, step);
941  if (top_n_flag == TN_TOP2 && code != null_char_) {
942  float prob = outputs[code] + outputs[null_char_];
943  if (prev != nullptr && prev_cont == NC_ANYTHING &&
944  prev->code != null_char_ &&
945  ((prev->code == top_code_ && code == second_code_) ||
946  (code == top_code_ && prev->code == second_code_))) {
947  prob += outputs[prev->code];
948  }
949  float cert = NetworkIO::ProbToCertainty(prob) + cert_offset;
950  PushDupOrNoDawgIfBetter(length + 1, false, code, INVALID_UNICHAR_ID,
951  cert, worst_dict_cert, dict_ratio, use_dawgs,
952  NC_ONLY_DUP, prev, step);
953  }
954  }
955  }
956 }
957 
958 // Continues for a new unichar, using dawg or non-dawg as per flag.
959 void RecodeBeamSearch::ContinueUnichar(int code, int unichar_id, float cert,
960  float worst_dict_cert, float dict_ratio,
961  bool use_dawgs, NodeContinuation cont,
962  const RecodeNode* prev,
963  RecodeBeam* step) {
964  if (use_dawgs) {
965  if (cert > worst_dict_cert) {
966  ContinueDawg(code, unichar_id, cert, cont, prev, step);
967  }
968  } else {
969  RecodeHeap* nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
970  PushHeapIfBetter(kBeamWidths[0], code, unichar_id, TOP_CHOICE_PERM, false,
971  false, false, false, cert * dict_ratio, prev, nullptr,
972  nodawg_heap);
973  if (dict_ != nullptr &&
974  ((unichar_id == UNICHAR_SPACE && cert > worst_dict_cert) ||
975  !dict_->getUnicharset().IsSpaceDelimited(unichar_id))) {
976  // Any top choice position that can start a new word, ie a space or
977  // any non-space-delimited character, should also be considered
978  // by the dawg search, so push initial dawg to the dawg heap.
979  float dawg_cert = cert;
980  PermuterType permuter = TOP_CHOICE_PERM;
981  // Since we use the space either side of a dictionary word in the
982  // certainty of the word, (to properly handle weak spaces) and the
983  // space is coming from a non-dict word, we need special conditions
984  // to avoid degrading the certainty of the dict word that follows.
985  // With a space we don't multiply the certainty by dict_ratio, and we
986  // flag the space with NO_PERM to indicate that we should not use the
987  // predecessor nulls to generate the confidence for the space, as they
988  // have already been multiplied by dict_ratio, and we can't go back to
989  // insert more entries in any previous heaps.
990  if (unichar_id == UNICHAR_SPACE)
991  permuter = NO_PERM;
992  else
993  dawg_cert *= dict_ratio;
994  PushInitialDawgIfBetter(code, unichar_id, permuter, false, false,
995  dawg_cert, cont, prev, step);
996  }
997  }
998 }
999 
1000 // Adds a RecodeNode composed of the tuple (code, unichar_id, cert, prev,
1001 // appropriate-dawg-args, cert) to the given heap (dawg_beam_) if unichar_id
1002 // is a valid continuation of whatever is in prev.
1003 void RecodeBeamSearch::ContinueDawg(int code, int unichar_id, float cert,
1004  NodeContinuation cont,
1005  const RecodeNode* prev, RecodeBeam* step) {
1006  RecodeHeap* dawg_heap = &step->beams_[BeamIndex(true, cont, 0)];
1007  RecodeHeap* nodawg_heap = &step->beams_[BeamIndex(false, cont, 0)];
1008  if (unichar_id == INVALID_UNICHAR_ID) {
1009  PushHeapIfBetter(kBeamWidths[0], code, unichar_id, NO_PERM, false, false,
1010  false, false, cert, prev, nullptr, dawg_heap);
1011  return;
1012  }
1013  // Avoid dictionary probe if score a total loss.
1014  float score = cert;
1015  if (prev != nullptr) score += prev->score;
1016  if (dawg_heap->size() >= kBeamWidths[0] &&
1017  score <= dawg_heap->PeekTop().data.score &&
1018  nodawg_heap->size() >= kBeamWidths[0] &&
1019  score <= nodawg_heap->PeekTop().data.score) {
1020  return;
1021  }
1022  const RecodeNode* uni_prev = prev;
1023  // Prev may be a partial code, null_char, or duplicate, so scan back to the
1024  // last valid unichar_id.
1025  while (uni_prev != nullptr &&
1026  (uni_prev->unichar_id == INVALID_UNICHAR_ID || uni_prev->duplicate))
1027  uni_prev = uni_prev->prev;
1028  if (unichar_id == UNICHAR_SPACE) {
1029  if (uni_prev != nullptr && uni_prev->end_of_word) {
1030  // Space is good. Push initial state, to the dawg beam and a regular
1031  // space to the top choice beam.
1032  PushInitialDawgIfBetter(code, unichar_id, uni_prev->permuter, false,
1033  false, cert, cont, prev, step);
1034  PushHeapIfBetter(kBeamWidths[0], code, unichar_id, uni_prev->permuter,
1035  false, false, false, false, cert, prev, nullptr,
1036  nodawg_heap);
1037  }
1038  return;
1039  } else if (uni_prev != nullptr && uni_prev->start_of_dawg &&
1040  uni_prev->unichar_id != UNICHAR_SPACE &&
1041  dict_->getUnicharset().IsSpaceDelimited(uni_prev->unichar_id) &&
1042  dict_->getUnicharset().IsSpaceDelimited(unichar_id)) {
1043  return; // Can't break words between space delimited chars.
1044  }
1045  DawgPositionVector initial_dawgs;
1046  auto* updated_dawgs = new DawgPositionVector;
1047  DawgArgs dawg_args(&initial_dawgs, updated_dawgs, NO_PERM);
1048  bool word_start = false;
1049  if (uni_prev == nullptr) {
1050  // Starting from beginning of line.
1051  dict_->default_dawgs(&initial_dawgs, false);
1052  word_start = true;
1053  } else if (uni_prev->dawgs != nullptr) {
1054  // Continuing a previous dict word.
1055  dawg_args.active_dawgs = uni_prev->dawgs;
1056  word_start = uni_prev->start_of_dawg;
1057  } else {
1058  return; // Can't continue if not a dict word.
1059  }
1060  auto permuter = static_cast<PermuterType>(
1061  dict_->def_letter_is_okay(&dawg_args,
1062  dict_->getUnicharset(), unichar_id, false));
1063  if (permuter != NO_PERM) {
1064  PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1065  word_start, dawg_args.valid_end, false, cert, prev,
1066  dawg_args.updated_dawgs, dawg_heap);
1067  if (dawg_args.valid_end && !space_delimited_) {
1068  // We can start another word right away, so push initial state as well,
1069  // to the dawg beam, and the regular character to the top choice beam,
1070  // since non-dict words can start here too.
1071  PushInitialDawgIfBetter(code, unichar_id, permuter, word_start, true,
1072  cert, cont, prev, step);
1073  PushHeapIfBetter(kBeamWidths[0], code, unichar_id, permuter, false,
1074  word_start, true, false, cert, prev, nullptr,
1075  nodawg_heap);
1076  }
1077  } else {
1078  delete updated_dawgs;
1079  }
1080 }
1081 
1082 // Adds a RecodeNode composed of the tuple (code, unichar_id,
1083 // initial-dawg-state, prev, cert) to the given heap if/ there is room or if
1084 // better than the current worst element if already full.
1085 void RecodeBeamSearch::PushInitialDawgIfBetter(int code, int unichar_id,
1086  PermuterType permuter,
1087  bool start, bool end, float cert,
1088  NodeContinuation cont,
1089  const RecodeNode* prev,
1090  RecodeBeam* step) {
1091  RecodeNode* best_initial_dawg = &step->best_initial_dawgs_[cont];
1092  float score = cert;
1093  if (prev != nullptr) score += prev->score;
1094  if (best_initial_dawg->code < 0 || score > best_initial_dawg->score) {
1095  auto* initial_dawgs = new DawgPositionVector;
1096  dict_->default_dawgs(initial_dawgs, false);
1097  RecodeNode node(code, unichar_id, permuter, true, start, end, false, cert,
1098  score, prev, initial_dawgs,
1099  ComputeCodeHash(code, false, prev));
1100  *best_initial_dawg = node;
1101  }
1102 }
1103 
1104 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1105 // false, false, false, false, cert, prev, nullptr) to heap if there is room
1106 // or if better than the current worst element if already full.
1107 /* static */
1108 void RecodeBeamSearch::PushDupOrNoDawgIfBetter(
1109  int length, bool dup, int code, int unichar_id, float cert,
1110  float worst_dict_cert, float dict_ratio, bool use_dawgs,
1111  NodeContinuation cont, const RecodeNode* prev, RecodeBeam* step) {
1112  int index = BeamIndex(use_dawgs, cont, length);
1113  if (use_dawgs) {
1114  if (cert > worst_dict_cert) {
1115  PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1116  prev ? prev->permuter : NO_PERM, false, false, false,
1117  dup, cert, prev, nullptr, &step->beams_[index]);
1118  }
1119  } else {
1120  cert *= dict_ratio;
1121  if (cert >= kMinCertainty || code == null_char_) {
1122  PushHeapIfBetter(kBeamWidths[length], code, unichar_id,
1123  prev ? prev->permuter : TOP_CHOICE_PERM, false, false,
1124  false, dup, cert, prev, nullptr, &step->beams_[index]);
1125  }
1126  }
1127 }
1128 
1129 // Adds a RecodeNode composed of the tuple (code, unichar_id, permuter,
1130 // dawg_start, word_start, end, dup, cert, prev, d) to heap if there is room
1131 // or if better than the current worst element if already full.
1132 void RecodeBeamSearch::PushHeapIfBetter(int max_size, int code, int unichar_id,
1133  PermuterType permuter, bool dawg_start,
1134  bool word_start, bool end, bool dup,
1135  float cert, const RecodeNode* prev,
1136  DawgPositionVector* d,
1137  RecodeHeap* heap) {
1138  float score = cert;
1139  if (prev != nullptr) score += prev->score;
1140  if (heap->size() < max_size || score > heap->PeekTop().data.score) {
1141  uint64_t hash = ComputeCodeHash(code, dup, prev);
1142  RecodeNode node(code, unichar_id, permuter, dawg_start, word_start, end,
1143  dup, cert, score, prev, d, hash);
1144  if (UpdateHeapIfMatched(&node, heap)) return;
1145  RecodePair entry(score, node);
1146  heap->Push(&entry);
1147  ASSERT_HOST(entry.data.dawgs == nullptr);
1148  if (heap->size() > max_size) heap->Pop(&entry);
1149  } else {
1150  delete d;
1151  }
1152 }
1153 
1154 // Adds a RecodeNode to heap if there is room
1155 // or if better than the current worst element if already full.
1156 void RecodeBeamSearch::PushHeapIfBetter(int max_size, RecodeNode* node,
1157  RecodeHeap* heap) {
1158  if (heap->size() < max_size || node->score > heap->PeekTop().data.score) {
1159  if (UpdateHeapIfMatched(node, heap)) {
1160  return;
1161  }
1162  RecodePair entry(node->score, *node);
1163  heap->Push(&entry);
1164  ASSERT_HOST(entry.data.dawgs == nullptr);
1165  if (heap->size() > max_size) heap->Pop(&entry);
1166  }
1167 }
1168 
1169 // Searches the heap for a matching entry, and updates the score with
1170 // reshuffle if needed. Returns true if there was a match.
1171 bool RecodeBeamSearch::UpdateHeapIfMatched(RecodeNode* new_node,
1172  RecodeHeap* heap) {
1173  // TODO(rays) consider hash map instead of linear search.
1174  // It might not be faster because the hash map would have to be updated
1175  // every time a heap reshuffle happens, and that would be a lot of overhead.
1176  GenericVector<RecodePair>* nodes = heap->heap();
1177  for (int i = 0; i < nodes->size(); ++i) {
1178  RecodeNode& node = (*nodes)[i].data;
1179  if (node.code == new_node->code && node.code_hash == new_node->code_hash &&
1180  node.permuter == new_node->permuter &&
1181  node.start_of_dawg == new_node->start_of_dawg) {
1182  if (new_node->score > node.score) {
1183  // The new one is better. Update the entire node in the heap and
1184  // reshuffle.
1185  node = *new_node;
1186  (*nodes)[i].key = node.score;
1187  heap->Reshuffle(&(*nodes)[i]);
1188  }
1189  return true;
1190  }
1191  }
1192  return false;
1193 }
1194 
1195 // Computes and returns the code-hash for the given code and prev.
1196 uint64_t RecodeBeamSearch::ComputeCodeHash(int code, bool dup,
1197  const RecodeNode* prev) const {
1198  uint64_t hash = prev == nullptr ? 0 : prev->code_hash;
1199  if (!dup && code != null_char_) {
1200  int num_classes = recoder_.code_range();
1201  uint64_t carry = (((hash >> 32) * num_classes) >> 32);
1202  hash *= num_classes;
1203  hash += carry;
1204  hash += code;
1205  }
1206  return hash;
1207 }
1208 
1209 // Backtracks to extract the best path through the lattice that was built
1210 // during Decode. On return the best_nodes vector essentially contains the set
1211 // of code, score pairs that make the optimal path with the constraint that
1212 // the recoder can decode the code sequence back to a sequence of unichar-ids.
1213 void RecodeBeamSearch::ExtractBestPaths(
1215  GenericVector<const RecodeNode*>* second_nodes) const {
1216  // Scan both beams to extract the best and second best paths.
1217  const RecodeNode* best_node = nullptr;
1218  const RecodeNode* second_best_node = nullptr;
1219  const RecodeBeam* last_beam = beam_[beam_size_ - 1];
1220  for (int c = 0; c < NC_COUNT; ++c) {
1221  if (c == NC_ONLY_DUP) continue;
1222  auto cont = static_cast<NodeContinuation>(c);
1223  for (int is_dawg = 0; is_dawg < 2; ++is_dawg) {
1224  int beam_index = BeamIndex(is_dawg, cont, 0);
1225  int heap_size = last_beam->beams_[beam_index].size();
1226  for (int h = 0; h < heap_size; ++h) {
1227  const RecodeNode* node = &last_beam->beams_[beam_index].get(h).data;
1228  if (is_dawg) {
1229  // dawg_node may be a null_char, or duplicate, so scan back to the
1230  // last valid unichar_id.
1231  const RecodeNode* dawg_node = node;
1232  while (dawg_node != nullptr &&
1233  (dawg_node->unichar_id == INVALID_UNICHAR_ID ||
1234  dawg_node->duplicate))
1235  dawg_node = dawg_node->prev;
1236  if (dawg_node == nullptr ||
1237  (!dawg_node->end_of_word &&
1238  dawg_node->unichar_id != UNICHAR_SPACE)) {
1239  // Dawg node is not valid.
1240  continue;
1241  }
1242  }
1243  if (best_node == nullptr || node->score > best_node->score) {
1244  second_best_node = best_node;
1245  best_node = node;
1246  } else if (second_best_node == nullptr ||
1247  node->score > second_best_node->score) {
1248  second_best_node = node;
1249  }
1250  }
1251  }
1252  }
1253  if (second_nodes != nullptr) ExtractPath(second_best_node, second_nodes);
1254  ExtractPath(best_node, best_nodes);
1255 }
1256 
1257 // Helper backtracks through the lattice from the given node, storing the
1258 // path and reversing it.
1259 void RecodeBeamSearch::ExtractPath(
1260  const RecodeNode* node, GenericVector<const RecodeNode*>* path) const {
1261  path->truncate(0);
1262  while (node != nullptr) {
1263  path->push_back(node);
1264  node = node->prev;
1265  }
1266  path->reverse();
1267 }
1268 
1269 void RecodeBeamSearch::ExtractPath(
1270  const RecodeNode* node, GenericVector<const RecodeNode*>* path,
1271  int limiter) const {
1272  int pathcounter = 0;
1273  path->truncate(0);
1274  while (node != nullptr && pathcounter < limiter) {
1275  path->push_back(node);
1276  node = node->prev;
1277  ++pathcounter;
1278  }
1279  path->reverse();
1280 }
1281 
1282 // Helper prints debug information on the given lattice path.
1283 void RecodeBeamSearch::DebugPath(
1284  const UNICHARSET* unicharset,
1285  const GenericVector<const RecodeNode*>& path) const {
1286  for (int c = 0; c < path.size(); ++c) {
1287  const RecodeNode& node = *path[c];
1288  tprintf("%d ", c);
1289  node.Print(null_char_, *unicharset, 1);
1290  }
1291 }
1292 
1293 // Helper prints debug information on the given unichar path.
1294 void RecodeBeamSearch::DebugUnicharPath(
1295  const UNICHARSET* unicharset, const GenericVector<const RecodeNode*>& path,
1296  const GenericVector<int>& unichar_ids, const GenericVector<float>& certs,
1297  const GenericVector<float>& ratings,
1298  const GenericVector<int>& xcoords) const {
1299  int num_ids = unichar_ids.size();
1300  double total_rating = 0.0;
1301  for (int c = 0; c < num_ids; ++c) {
1302  int coord = xcoords[c];
1303  tprintf("%d %d=%s r=%g, c=%g, s=%d, e=%d, perm=%d\n", coord, unichar_ids[c],
1304  unicharset->debug_str(unichar_ids[c]).c_str(), ratings[c],
1305  certs[c], path[coord]->start_of_word, path[coord]->end_of_word,
1306  path[coord]->permuter);
1307  total_rating += ratings[c];
1308  }
1309  tprintf("Path total rating = %g\n", total_rating);
1310 }
1311 
1312 } // namespace tesseract.
tesseract::GenericHeap< RecodePair >
tesseract::NC_NO_DUP
Definition: recodebeam.h:78
tesseract::RecodeBeamSearch::kMinCertainty
static constexpr float kMinCertainty
Definition: recodebeam.h:252
WERD_RES::FakeWordFromRatings
void FakeWordFromRatings(PermuterType permuter)
Definition: pageres.cpp:894
tesseract::RecodeBeamSearch::excludedUnichars
std::vector< std::unordered_set< int > > excludedUnichars
Definition: recodebeam.h:245
C_BLOB::FakeBlob
static C_BLOB * FakeBlob(const TBOX &box)
Definition: stepblob.cpp:236
tesseract::TN_TOP2
Definition: recodebeam.h:86
pageres.h
tesseract::RecodeBeamSearch::ctc_choices
std::vector< std::vector< std::pair< const char *, float > > > ctc_choices
Definition: recodebeam.h:243
tesseract::RecodeNode::unichar_id
int unichar_id
Definition: recodebeam.h:144
networkio.h
tesseract::TopNState
TopNState
Definition: recodebeam.h:85
UNICHARSET::id_to_unichar_ext
const char * id_to_unichar_ext(UNICHAR_ID id) const
Definition: unicharset.cpp:298
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
PermuterType
PermuterType
Definition: ratngs.h:230
tesseract::NodeContinuation
NodeContinuation
Definition: recodebeam.h:73
ICOORD
integer coordinate
Definition: points.h:30
tesseract::TN_ALSO_RAN
Definition: recodebeam.h:88
MATRIX
Definition: matrix.h:574
tesseract::PointerVector< WERD_RES >
tesseract::RecodeBeamSearch::LengthFromBeamsIndex
static int LengthFromBeamsIndex(int index)
Definition: recodebeam.h:259
NO_PERM
Definition: ratngs.h:231
TBOX::top
int16_t top() const
Definition: rect.h:57
tesseract::NetworkIO::Width
int Width() const
Definition: networkio.h:107
recodebeam.h
tesseract::UnicharCompress::code_range
int code_range() const
Definition: unicharcompress.h:161
WERD_RES
Definition: pageres.h:160
tesseract::RecodeBeamSearch::timesteps
std::vector< std::vector< std::pair< const char *, float > > > timesteps
Definition: recodebeam.h:239
tesseract::PointerVector::truncate
void truncate(int size)
Definition: genericvector.h:457
tesseract::PointerVector::clear
void clear()
Definition: genericvector.h:490
tesseract::RecodeNode
Definition: recodebeam.h:93
UNICHARSET::IsSpaceDelimited
bool IsSpaceDelimited(UNICHAR_ID unichar_id) const
Definition: unicharset.h:642
tesseract::RecodeBeamSearch::Decode
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:78
WERD_RES::ratings
MATRIX * ratings
Definition: pageres.h:231
tesseract::Dict::default_dawgs
void default_dawgs(DawgPositionVector *anylength_dawgs, bool suppress_patterns) const
Definition: dict.cpp:617
tesseract::RecodeNode::start_of_dawg
bool start_of_dawg
Definition: recodebeam.h:151
GENERIC_2D_ARRAY< float >
tesseract::RecodeNode::prev
const RecodeNode * prev
Definition: recodebeam.h:168
tesseract::RecodeBeamSearch::PrintBeam2
void PrintBeam2(bool uids, int num_outputs, const UNICHARSET *charset, bool secondary) const
Definition: recodebeam.cpp:315
GenericVector::back
T & back() const
Definition: genericvector.h:728
TBOX::height
int16_t height() const
Definition: rect.h:107
GenericVector::reverse
void reverse()
Definition: genericvector.h:215
tesseract::GenericHeap::size
int size() const
Definition: genericheap.h:71
tesseract::RecodeNode::score
float score
Definition: recodebeam.h:166
tesseract::TN_TOPN
Definition: recodebeam.h:87
tesseract::greater_than::operator()
bool operator()(const RecodeNode *&node1, const RecodeNode *&node2)
Definition: recodebeam.cpp:309
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::RecodeBeamSearch::extractSymbolChoices
void extractSymbolChoices(const UNICHARSET *unicharset)
Definition: recodebeam.cpp:395
UNICHARSET::debug_str
STRING debug_str(UNICHAR_ID id) const
Definition: unicharset.cpp:342
tesseract::UnicharCompress::DecodeUnichar
int DecodeUnichar(const RecodedCharID &code) const
Definition: unicharcompress.cpp:291
tesseract::RecodeNode::code_hash
uint64_t code_hash
Definition: recodebeam.h:173
GENERIC_2D_ARRAY::dim2
int dim2() const
Definition: matrix.h:206
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
UNICHARSET::get_enabled
bool get_enabled(UNICHAR_ID unichar_id) const
Definition: unicharset.h:868
UNICHAR_SPACE
Definition: unicharset.h:34
tesseract::TN_COUNT
Definition: recodebeam.h:89
tesseract::NetworkIO::f
float * f(int t)
Definition: networkio.h:115
tesseract::KDPair::data
Data data
Definition: kdpair.h:45
tesseract::NC_ONLY_DUP
Definition: recodebeam.h:75
tesseract::Dict::def_letter_is_okay
int def_letter_is_okay(void *void_dawg_args, const UNICHARSET &unicharset, UNICHAR_ID unichar_id, bool word_end) const
Definition: dict.cpp:395
GenericVector::empty
bool empty() const
Definition: genericvector.h:86
UNICHARSET
Definition: unicharset.h:145
tesseract::NetworkIO
Definition: networkio.h:39
TBOX::bottom
int16_t bottom() const
Definition: rect.h:64
tesseract::RecodeBeamSearch::IsDawgFromBeamsIndex
static bool IsDawgFromBeamsIndex(int index)
Definition: recodebeam.h:263
tesseract::NC_COUNT
Definition: recodebeam.h:81
character
Definition: mfoutline.h:62
tesseract
Definition: baseapi.h:65
null_char_
int null_char_
Definition: unicharcompress_test.cc:168
tesseract::RecodeBeamSearch::DecodeSecondaryBeams
void DecodeSecondaryBeams(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
Definition: recodebeam.cpp:105
TOP_CHOICE_PERM
Definition: ratngs.h:233
tesseract::RecodeBeamSearch::kNumBeams
static const int kNumBeams
Definition: recodebeam.h:257
tesseract::RecodeBeamSearch::segmentTimestepsByCharacters
void segmentTimestepsByCharacters()
Definition: recodebeam.cpp:156
tesseract::RecodeBeamSearch::BeamIndex
static int BeamIndex(bool is_dawg, NodeContinuation cont, int length)
Definition: recodebeam.h:267
GenericVector< int >
tesseract::RecodeBeamSearch::ContinuationFromBeamsIndex
static NodeContinuation ContinuationFromBeamsIndex(int index)
Definition: recodebeam.h:260
tesseract::Dict
Definition: dict.h:91
tesseract::NetworkIO::ProbToCertainty
static float ProbToCertainty(float prob)
Definition: networkio.cpp:568
tesseract::UnicharCompress::GetFinalCodes
const GenericVector< int > * GetFinalCodes(const RecodedCharID &code) const
Definition: unicharcompress.h:179
BLOB_CHOICE
Definition: ratngs.h:49
tesseract::RecodeNode::start_of_word
bool start_of_word
Definition: recodebeam.h:153
tesseract::RecodedCharID::kMaxCodeLen
static const int kMaxCodeLen
Definition: unicharcompress.h:37
WERD
Definition: werd.h:55
tesseract::RecodeNode::certainty
float certainty
Definition: recodebeam.h:164
GenericVector::truncate
void truncate(int size)
Definition: genericvector.h:132
TBOX::left
int16_t left() const
Definition: rect.h:71
tesseract::RecodeNode::end_of_word
bool end_of_word
Definition: recodebeam.h:157
GenericVector::get
T & get(int index) const
Definition: genericvector.h:716
tesseract::RecodeBeamSearch::combineSegmentedTimesteps
std::vector< std::vector< std::pair< const char *, float > > > combineSegmentedTimesteps(std::vector< std::vector< std::vector< std::pair< const char *, float >>>> *segmentedTimesteps)
Definition: recodebeam.cpp:166
unicharcompress.h
tesseract::RecodeNode::permuter
PermuterType permuter
Definition: recodebeam.h:148
GenericVector::init_to_size
void init_to_size(int size, const T &t)
Definition: genericvector.h:706
GENERIC_2D_ARRAY::put
void put(ICOORD pos, const T &thing)
Definition: matrix.h:219
tesseract::UnicharCompress::GetNextCodes
const GenericVector< int > * GetNextCodes(const RecodedCharID &code) const
Definition: unicharcompress.h:173
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::RecodeBeamSearch::ExtractBestPathAsWords
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words, int lstm_choice_mode=0)
Definition: recodebeam.cpp:230
tesseract::RecodeBeamSearch::DebugBeams
void DebugBeams(const UNICHARSET &unicharset) const
Definition: recodebeam.cpp:495
tesseract::Dict::getUnicharset
const UNICHARSET & getUnicharset() const
Definition: dict.h:101
tesseract::RecodeBeamSearch::character_boundaries_
std::vector< int > character_boundaries_
Definition: recodebeam.h:247
tesseract::greater_than
Definition: recodebeam.cpp:308
tesseract::RecodeHeap
GenericHeap< RecodePair > RecodeHeap
Definition: recodebeam.h:177
tesseract::Dict::IsSpaceDelimitedLang
bool IsSpaceDelimitedLang() const
Returns true if the language is space-delimited (not CJ, or T).
Definition: dict.cpp:883
UNICHARSET::id_to_unichar
const char * id_to_unichar(UNICHAR_ID id) const
Definition: unicharset.cpp:290
tesseract::NetworkIO::NumFeatures
int NumFeatures() const
Definition: networkio.h:111
tesseract::RecodeBeamSearch::ExtractBestPathAsUnicharIds
void ExtractBestPathAsUnicharIds(bool debug, const UNICHARSET *unicharset, GenericVector< int > *unichar_ids, GenericVector< float > *certs, GenericVector< float > *ratings, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:215
tesseract::RecodePair
KDPairInc< double, RecodeNode > RecodePair
Definition: recodebeam.h:176
tesseract::RecodeBeamSearch::RecodeBeamSearch
RecodeBeamSearch(const UnicharCompress &recoder, int null_char, bool simple_text, Dict *dict)
Definition: recodebeam.cpp:63
tesseract::RecodeNode::Print
void Print(int null_char, const UNICHARSET &unicharset, int depth) const
Definition: recodebeam.cpp:43
tesseract::UnicharCompress
Definition: unicharcompress.h:128
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::GenericHeap::get
const Pair & get(int index) const
Definition: genericheap.h:87
UNICHARSET::size
int size() const
Definition: unicharset.h:341
tesseract::NC_ANYTHING
Definition: recodebeam.h:74
TBOX::scale
void scale(const float f)
Definition: rect.h:174
GENERIC_2D_ARRAY::dim1
int dim1() const
Definition: matrix.h:205
tesseract::RecodeBeamSearch::ExtractBestPathAsLabels
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:192
tesseract::RecodeBeamSearch::segmentedTimesteps
std::vector< std::vector< std::vector< std::pair< const char *, float > > > > segmentedTimesteps
Definition: recodebeam.h:241
tesseract::RecodeNode::code
int code
Definition: recodebeam.h:142
BCC_STATIC_CLASSIFIER
Definition: ratngs.h:42
TBOX
Definition: rect.h:33