All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
errorcounter.cpp
Go to the documentation of this file.
1 // Copyright 2011 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 // http://www.apache.org/licenses/LICENSE-2.0
8 // Unless required by applicable law or agreed to in writing, software
9 // distributed under the License is distributed on an "AS IS" BASIS,
10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 // See the License for the specific language governing permissions and
12 // limitations under the License.
13 //
15 #include <ctime>
16 
17 #include "errorcounter.h"
18 
19 #include "fontinfo.h"
20 #include "ndminx.h"
21 #include "sampleiterator.h"
22 #include "shapeclassifier.h"
23 #include "shapetable.h"
24 #include "trainingsample.h"
25 #include "trainingsampleset.h"
26 #include "unicity_table.h"
27 
28 namespace tesseract {
29 
// Two result ratings that differ by no more than this amount are treated as
// an "equal" choice when ranking classifier answers.
const double kRatingEpsilon = 0.03125;  // == 1.0 / 32
32 
33 // Tests a classifier, computing its error rate.
34 // See errorcounter.h for description of arguments.
35 // Iterates over the samples, calling the classifier in normal/silent mode.
36 // If the classifier makes a CT_UNICHAR_TOPN_ERR error, and the appropriate
37 // report_level is set (4 or greater), it will then call the classifier again
38 // with a debug flag and a keep_this argument to find out what is going on.
40  int report_level, CountTypes boosting_mode,
41  const FontInfoTable& fontinfo_table,
42  const GenericVector<Pix*>& page_images, SampleIterator* it,
43  double* unichar_error, double* scaled_error, STRING* fonts_report) {
44  int fontsize = it->sample_set()->NumFonts();
45  ErrorCounter counter(classifier->GetUnicharset(), fontsize);
47 
48  clock_t start = clock();
49  int total_samples = 0;
50  double unscaled_error = 0.0;
51  // Set a number of samples on which to run the classify debug mode.
52  int error_samples = report_level > 3 ? report_level * report_level : 0;
53  // Iterate over all the samples, accumulating errors.
54  for (it->Begin(); !it->AtEnd(); it->Next()) {
55  TrainingSample* mutable_sample = it->MutableSample();
56  int page_index = mutable_sample->page_num();
57  Pix* page_pix = 0 <= page_index && page_index < page_images.size()
58  ? page_images[page_index] : NULL;
59  // No debug, no keep this.
60  classifier->UnicharClassifySample(*mutable_sample, page_pix, 0,
61  INVALID_UNICHAR_ID, &results);
62  bool debug_it = false;
63  int correct_id = mutable_sample->class_id();
64  if (counter.unicharset_.has_special_codes() &&
65  (correct_id == UNICHAR_SPACE || correct_id == UNICHAR_JOINED ||
66  correct_id == UNICHAR_BROKEN)) {
67  // This is junk so use the special counter.
68  debug_it = counter.AccumulateJunk(report_level > 3,
69  results,
70  mutable_sample);
71  } else {
72  debug_it = counter.AccumulateErrors(report_level > 3, boosting_mode,
73  fontinfo_table,
74  results, mutable_sample);
75  }
76  if (debug_it && error_samples > 0) {
77  // Running debug, keep the correct answer, and debug the classifier.
78  tprintf("Error on sample %d: %s Classifier debug output:\n",
79  it->GlobalSampleIndex(),
80  it->sample_set()->SampleToString(*mutable_sample).string());
81  classifier->DebugDisplay(*mutable_sample, page_pix, correct_id);
82  --error_samples;
83  }
84  ++total_samples;
85  }
86  double total_time = 1.0 * (clock() - start) / CLOCKS_PER_SEC;
87  // Create the appropriate error report.
88  unscaled_error = counter.ReportErrors(report_level, boosting_mode,
89  fontinfo_table,
90  *it, unichar_error, fonts_report);
91  if (scaled_error != NULL) *scaled_error = counter.scaled_error_;
92  if (report_level > 1) {
93  // It is useful to know the time in microseconds/char.
94  tprintf("Errors computed in %.2fs at %.1f μs/char\n",
95  total_time, 1000000.0 * total_time / total_samples);
96  }
97  return unscaled_error;
98 }
99 
100 // Tests a pair of classifiers, debugging errors of the new against the old.
101 // See errorcounter.h for description of arguments.
102 // Iterates over the samples, calling the classifiers in normal/silent mode.
103 // If the new_classifier makes a boosting_mode error that the old_classifier
104 // does not, it will then call the new_classifier again with a debug flag
105 // and a keep_this argument to find out what is going on.
107  ShapeClassifier* new_classifier, ShapeClassifier* old_classifier,
108  CountTypes boosting_mode,
109  const FontInfoTable& fontinfo_table,
110  const GenericVector<Pix*>& page_images, SampleIterator* it) {
111  int fontsize = it->sample_set()->NumFonts();
112  ErrorCounter old_counter(old_classifier->GetUnicharset(), fontsize);
113  ErrorCounter new_counter(new_classifier->GetUnicharset(), fontsize);
115 
116  int total_samples = 0;
117  int error_samples = 25;
118  int total_new_errors = 0;
119  // Iterate over all the samples, accumulating errors.
120  for (it->Begin(); !it->AtEnd(); it->Next()) {
121  TrainingSample* mutable_sample = it->MutableSample();
122  int page_index = mutable_sample->page_num();
123  Pix* page_pix = 0 <= page_index && page_index < page_images.size()
124  ? page_images[page_index] : NULL;
125  // No debug, no keep this.
126  old_classifier->UnicharClassifySample(*mutable_sample, page_pix, 0,
127  INVALID_UNICHAR_ID, &results);
128  int correct_id = mutable_sample->class_id();
129  if (correct_id != 0 &&
130  !old_counter.AccumulateErrors(true, boosting_mode, fontinfo_table,
131  results, mutable_sample)) {
132  // old classifier was correct, check the new one.
133  new_classifier->UnicharClassifySample(*mutable_sample, page_pix, 0,
134  INVALID_UNICHAR_ID, &results);
135  if (correct_id != 0 &&
136  new_counter.AccumulateErrors(true, boosting_mode, fontinfo_table,
137  results, mutable_sample)) {
138  tprintf("New Error on sample %d: Classifier debug output:\n",
139  it->GlobalSampleIndex());
140  ++total_new_errors;
141  new_classifier->UnicharClassifySample(*mutable_sample, page_pix, 1,
142  correct_id, &results);
143  if (results.size() > 0 && error_samples > 0) {
144  new_classifier->DebugDisplay(*mutable_sample, page_pix, correct_id);
145  --error_samples;
146  }
147  }
148  }
149  ++total_samples;
150  }
151  tprintf("Total new errors = %d\n", total_new_errors);
152 }
153 
154 // Constructor is private. Only anticipated use of ErrorCounter is via
155 // the static ComputeErrorRate.
156 ErrorCounter::ErrorCounter(const UNICHARSET& unicharset, int fontsize)
157  : scaled_error_(0.0), rating_epsilon_(kRatingEpsilon),
158  unichar_counts_(unicharset.size(), unicharset.size(), 0),
159  ok_score_hist_(0, 101), bad_score_hist_(0, 101),
160  unicharset_(unicharset) {
161  Counts empty_counts;
162  font_counts_.init_to_size(fontsize, empty_counts);
163  multi_unichar_counts_.init_to_size(unicharset.size(), 0);
164 }
165 ErrorCounter::~ErrorCounter() {
166 }
167 
168 // Accumulates the errors from the classifier results on a single sample.
169 // Returns true if debug is true and a CT_UNICHAR_TOPN_ERR error occurred.
170 // boosting_mode selects the type of error to be used for boosting and the
171 // is_error_ member of sample is set according to whether the required type
172 // of error occurred. The font_table provides access to font properties
173 // for error counting and shape_table is used to understand the relationship
174 // between unichar_ids and shape_ids in the results
175 bool ErrorCounter::AccumulateErrors(bool debug, CountTypes boosting_mode,
176  const FontInfoTable& font_table,
177  const GenericVector<UnicharRating>& results,
178  TrainingSample* sample) {
179  int num_results = results.size();
180  int answer_actual_rank = -1;
181  int font_id = sample->font_id();
182  int unichar_id = sample->class_id();
183  sample->set_is_error(false);
184  if (num_results == 0) {
185  // Reject. We count rejects as a separate category, but still mark the
186  // sample as an error in case any training module wants to use that to
187  // improve the classifier.
188  sample->set_is_error(true);
189  ++font_counts_[font_id].n[CT_REJECT];
190  } else {
191  // Find rank of correct unichar answer, using rating_epsilon_ to allow
192  // different answers to score as equal. (Ignoring the font.)
193  int epsilon_rank = 0;
194  int answer_epsilon_rank = -1;
195  int num_top_answers = 0;
196  double prev_rating = results[0].rating;
197  bool joined = false;
198  bool broken = false;
199  int res_index = 0;
200  while (res_index < num_results) {
201  if (results[res_index].rating < prev_rating - rating_epsilon_) {
202  ++epsilon_rank;
203  prev_rating = results[res_index].rating;
204  }
205  if (results[res_index].unichar_id == unichar_id &&
206  answer_epsilon_rank < 0) {
207  answer_epsilon_rank = epsilon_rank;
208  answer_actual_rank = res_index;
209  }
210  if (results[res_index].unichar_id == UNICHAR_JOINED &&
211  unicharset_.has_special_codes())
212  joined = true;
213  else if (results[res_index].unichar_id == UNICHAR_BROKEN &&
214  unicharset_.has_special_codes())
215  broken = true;
216  else if (epsilon_rank == 0)
217  ++num_top_answers;
218  ++res_index;
219  }
220  if (answer_actual_rank != 0) {
221  // Correct result is not absolute top.
222  ++font_counts_[font_id].n[CT_UNICHAR_TOPTOP_ERR];
223  if (boosting_mode == CT_UNICHAR_TOPTOP_ERR) sample->set_is_error(true);
224  }
225  if (answer_epsilon_rank == 0) {
226  ++font_counts_[font_id].n[CT_UNICHAR_TOP_OK];
227  // Unichar OK, but count if multiple unichars.
228  if (num_top_answers > 1) {
229  ++font_counts_[font_id].n[CT_OK_MULTI_UNICHAR];
230  ++multi_unichar_counts_[unichar_id];
231  }
232  // Check to see if any font in the top choice has attributes that match.
233  // TODO(rays) It is easy to add counters for individual font attributes
234  // here if we want them.
235  if (font_table.SetContainsFontProperties(
236  font_id, results[answer_actual_rank].fonts)) {
237  // Font attributes were matched.
238  // Check for multiple properties.
239  if (font_table.SetContainsMultipleFontProperties(
240  results[answer_actual_rank].fonts))
241  ++font_counts_[font_id].n[CT_OK_MULTI_FONT];
242  } else {
243  // Font attributes weren't matched.
244  ++font_counts_[font_id].n[CT_FONT_ATTR_ERR];
245  }
246  } else {
247  // This is a top unichar error.
248  ++font_counts_[font_id].n[CT_UNICHAR_TOP1_ERR];
249  if (boosting_mode == CT_UNICHAR_TOP1_ERR) sample->set_is_error(true);
250  // Count maps from unichar id to wrong unichar id.
251  ++unichar_counts_(unichar_id, results[0].unichar_id);
252  if (answer_epsilon_rank < 0 || answer_epsilon_rank >= 2) {
253  // It is also a 2nd choice unichar error.
254  ++font_counts_[font_id].n[CT_UNICHAR_TOP2_ERR];
255  if (boosting_mode == CT_UNICHAR_TOP2_ERR) sample->set_is_error(true);
256  }
257  if (answer_epsilon_rank < 0) {
258  // It is also a top-n choice unichar error.
259  ++font_counts_[font_id].n[CT_UNICHAR_TOPN_ERR];
260  if (boosting_mode == CT_UNICHAR_TOPN_ERR) sample->set_is_error(true);
261  answer_epsilon_rank = epsilon_rank;
262  }
263  }
264  // Compute mean number of return values and mean rank of correct answer.
265  font_counts_[font_id].n[CT_NUM_RESULTS] += num_results;
266  font_counts_[font_id].n[CT_RANK] += answer_epsilon_rank;
267  if (joined)
268  ++font_counts_[font_id].n[CT_OK_JOINED];
269  if (broken)
270  ++font_counts_[font_id].n[CT_OK_BROKEN];
271  }
272  // If it was an error for boosting then sum the weight.
273  if (sample->is_error()) {
274  scaled_error_ += sample->weight();
275  if (debug) {
276  tprintf("%d results for char %s font %d :",
277  num_results, unicharset_.id_to_unichar(unichar_id),
278  font_id);
279  for (int i = 0; i < num_results; ++i) {
280  tprintf(" %.3f : %s\n",
281  results[i].rating,
282  unicharset_.id_to_unichar(results[i].unichar_id));
283  }
284  return true;
285  }
286  int percent = 0;
287  if (num_results > 0)
288  percent = IntCastRounded(results[0].rating * 100);
289  bad_score_hist_.add(percent, 1);
290  } else {
291  int percent = 0;
292  if (answer_actual_rank >= 0)
293  percent = IntCastRounded(results[answer_actual_rank].rating * 100);
294  ok_score_hist_.add(percent, 1);
295  }
296  return false;
297 }
298 
299 // Accumulates counts for junk. Counts only whether the junk was correctly
300 // rejected or not.
301 bool ErrorCounter::AccumulateJunk(bool debug,
302  const GenericVector<UnicharRating>& results,
303  TrainingSample* sample) {
304  // For junk we accept no answer, or an explicit shape answer matching the
305  // class id of the sample.
306  int num_results = results.size();
307  int font_id = sample->font_id();
308  int unichar_id = sample->class_id();
309  int percent = 0;
310  if (num_results > 0)
311  percent = IntCastRounded(results[0].rating * 100);
312  if (num_results > 0 && results[0].unichar_id != unichar_id) {
313  // This is a junk error.
314  ++font_counts_[font_id].n[CT_ACCEPTED_JUNK];
315  sample->set_is_error(true);
316  // It counts as an error for boosting too so sum the weight.
317  scaled_error_ += sample->weight();
318  bad_score_hist_.add(percent, 1);
319  return debug;
320  } else {
321  // Correctly rejected.
322  ++font_counts_[font_id].n[CT_REJECTED_JUNK];
323  sample->set_is_error(false);
324  ok_score_hist_.add(percent, 1);
325  }
326  return false;
327 }
328 
329 // Creates a report of the error rate. The report_level controls the detail
330 // that is reported to stderr via tprintf:
331 // 0 -> no output.
332 // >=1 -> bottom-line error rate.
333 // >=3 -> font-level error rate.
334 // boosting_mode determines the return value. It selects which (un-weighted)
335 // error rate to return.
336 // The fontinfo_table from MasterTrainer provides the names of fonts.
337 // The it determines the current subset of the training samples.
338 // If not NULL, the top-choice unichar error rate is saved in unichar_error.
339 // If not NULL, the report string is saved in fonts_report.
340 // (Ignoring report_level).
341 double ErrorCounter::ReportErrors(int report_level, CountTypes boosting_mode,
342  const FontInfoTable& fontinfo_table,
343  const SampleIterator& it,
344  double* unichar_error,
345  STRING* fonts_report) {
346  // Compute totals over all the fonts and report individual font results
347  // when required.
348  Counts totals;
349  int fontsize = font_counts_.size();
350  for (int f = 0; f < fontsize; ++f) {
351  // Accumulate counts over fonts.
352  totals += font_counts_[f];
353  STRING font_report;
354  if (ReportString(false, font_counts_[f], &font_report)) {
355  if (fonts_report != NULL) {
356  *fonts_report += fontinfo_table.get(f).name;
357  *fonts_report += ": ";
358  *fonts_report += font_report;
359  *fonts_report += "\n";
360  }
361  if (report_level > 2) {
362  // Report individual font error rates.
363  tprintf("%s: %s\n", fontinfo_table.get(f).name, font_report.string());
364  }
365  }
366  }
367  // Report the totals.
368  STRING total_report;
369  bool any_results = ReportString(true, totals, &total_report);
370  if (fonts_report != NULL && fonts_report->length() == 0) {
371  // Make sure we return something even if there were no samples.
372  *fonts_report = "NoSamplesFound: ";
373  *fonts_report += total_report;
374  *fonts_report += "\n";
375  }
376  if (report_level > 0) {
377  // Report the totals.
378  STRING total_report;
379  if (any_results) {
380  tprintf("TOTAL Scaled Err=%.4g%%, %s\n",
381  scaled_error_ * 100.0, total_report.string());
382  }
383  // Report the worst substitution error only for now.
384  if (totals.n[CT_UNICHAR_TOP1_ERR] > 0) {
385  int charsetsize = unicharset_.size();
386  int worst_uni_id = 0;
387  int worst_result_id = 0;
388  int worst_err = 0;
389  for (int u = 0; u < charsetsize; ++u) {
390  for (int v = 0; v < charsetsize; ++v) {
391  if (unichar_counts_(u, v) > worst_err) {
392  worst_err = unichar_counts_(u, v);
393  worst_uni_id = u;
394  worst_result_id = v;
395  }
396  }
397  }
398  if (worst_err > 0) {
399  tprintf("Worst error = %d:%s -> %s with %d/%d=%.2f%% errors\n",
400  worst_uni_id, unicharset_.id_to_unichar(worst_uni_id),
401  unicharset_.id_to_unichar(worst_result_id),
402  worst_err, totals.n[CT_UNICHAR_TOP1_ERR],
403  100.0 * worst_err / totals.n[CT_UNICHAR_TOP1_ERR]);
404  }
405  }
406  tprintf("Multi-unichar shape use:\n");
407  for (int u = 0; u < multi_unichar_counts_.size(); ++u) {
408  if (multi_unichar_counts_[u] > 0) {
409  tprintf("%d multiple answers for unichar: %s\n",
410  multi_unichar_counts_[u],
411  unicharset_.id_to_unichar(u));
412  }
413  }
414  tprintf("OK Score histogram:\n");
415  ok_score_hist_.print();
416  tprintf("ERROR Score histogram:\n");
417  bad_score_hist_.print();
418  }
419 
420  double rates[CT_SIZE];
421  if (!ComputeRates(totals, rates))
422  return 0.0;
423  // Set output values if asked for.
424  if (unichar_error != NULL)
425  *unichar_error = rates[CT_UNICHAR_TOP1_ERR];
426  return rates[boosting_mode];
427 }
428 
429 // Sets the report string to a combined human and machine-readable report
430 // string of the error rates.
431 // Returns false if there is no data, leaving report unchanged, unless
432 // even_if_empty is true.
433 bool ErrorCounter::ReportString(bool even_if_empty, const Counts& counts,
434  STRING* report) {
435  // Compute the error rates.
436  double rates[CT_SIZE];
437  if (!ComputeRates(counts, rates) && !even_if_empty)
438  return false;
439  // Using %.4g%%, the length of the output string should exactly match the
440  // length of the format string, but in case of overflow, allow for +eddd
441  // on each number.
442  const int kMaxExtraLength = 5; // Length of +eddd.
443  // Keep this format string and the snprintf in sync with the CountTypes enum.
444  const char* format_str = "Unichar=%.4g%%[1], %.4g%%[2], %.4g%%[n], %.4g%%[T] "
445  "Mult=%.4g%%, Jn=%.4g%%, Brk=%.4g%%, Rej=%.4g%%, "
446  "FontAttr=%.4g%%, Multi=%.4g%%, "
447  "Answers=%.3g, Rank=%.3g, "
448  "OKjunk=%.4g%%, Badjunk=%.4g%%";
449  int max_str_len = strlen(format_str) + kMaxExtraLength * (CT_SIZE - 1) + 1;
450  char* formatted_str = new char[max_str_len];
451  snprintf(formatted_str, max_str_len, format_str,
452  rates[CT_UNICHAR_TOP1_ERR] * 100.0,
453  rates[CT_UNICHAR_TOP2_ERR] * 100.0,
454  rates[CT_UNICHAR_TOPN_ERR] * 100.0,
455  rates[CT_UNICHAR_TOPTOP_ERR] * 100.0,
456  rates[CT_OK_MULTI_UNICHAR] * 100.0,
457  rates[CT_OK_JOINED] * 100.0,
458  rates[CT_OK_BROKEN] * 100.0,
459  rates[CT_REJECT] * 100.0,
460  rates[CT_FONT_ATTR_ERR] * 100.0,
461  rates[CT_OK_MULTI_FONT] * 100.0,
462  rates[CT_NUM_RESULTS],
463  rates[CT_RANK],
464  100.0 * rates[CT_REJECTED_JUNK],
465  100.0 * rates[CT_ACCEPTED_JUNK]);
466  *report = formatted_str;
467  delete [] formatted_str;
468  // Now append each field of counts with a tab in front so the result can
469  // be loaded into a spreadsheet.
470  for (int ct = 0; ct < CT_SIZE; ++ct)
471  report->add_str_int("\t", counts.n[ct]);
472  return true;
473 }
474 
475 // Computes the error rates and returns in rates which is an array of size
476 // CT_SIZE. Returns false if there is no data, leaving rates unchanged.
477 bool ErrorCounter::ComputeRates(const Counts& counts, double rates[CT_SIZE]) {
478  int ok_samples = counts.n[CT_UNICHAR_TOP_OK] + counts.n[CT_UNICHAR_TOP1_ERR] +
479  counts.n[CT_REJECT];
480  int junk_samples = counts.n[CT_REJECTED_JUNK] + counts.n[CT_ACCEPTED_JUNK];
481  // Compute rates for normal chars.
482  double denominator = static_cast<double>(MAX(ok_samples, 1));
483  for (int ct = 0; ct <= CT_RANK; ++ct)
484  rates[ct] = counts.n[ct] / denominator;
485  // Compute rates for junk.
486  denominator = static_cast<double>(MAX(junk_samples, 1));
487  for (int ct = CT_REJECTED_JUNK; ct <= CT_ACCEPTED_JUNK; ++ct)
488  rates[ct] = counts.n[ct] / denominator;
489  return ok_samples != 0 || junk_samples != 0;
490 }
491 
492 ErrorCounter::Counts::Counts() {
493  memset(n, 0, sizeof(n[0]) * CT_SIZE);
494 }
495 // Adds other into this for computing totals.
496 void ErrorCounter::Counts::operator+=(const Counts& other) {
497  for (int ct = 0; ct < CT_SIZE; ++ct)
498  n[ct] += other.n[ct];
499 }
500 
501 
502 } // namespace tesseract.
503 
504 
505 
506 
507 
int size() const
Definition: genericvector.h:72
virtual int UnicharClassifySample(const TrainingSample &sample, Pix *page_pix, int debug, UNICHAR_ID keep_this, GenericVector< UnicharRating > *results)
#define MAX(x, y)
Definition: ndminx.h:24
virtual void DebugDisplay(const TrainingSample &sample, Pix *page_pix, UNICHAR_ID unichar_id)
#define tprintf(...)
Definition: tprintf.h:31
TrainingSample * MutableSample() const
void add(inT32 value, inT32 count)
Definition: statistc.cpp:104
inT32 length() const
Definition: strngs.cpp:188
static double ComputeErrorRate(ShapeClassifier *classifier, int report_level, CountTypes boosting_mode, const FontInfoTable &fontinfo_table, const GenericVector< Pix * > &page_images, SampleIterator *it, double *unichar_error, double *scaled_error, STRING *fonts_report)
const double kRatingEpsilon
const char *const id_to_unichar(UNICHAR_ID id) const
Definition: unicharset.cpp:266
void init_to_size(int size, T t)
bool has_special_codes() const
Definition: unicharset.h:670
void add_str_int(const char *str, int number)
Definition: strngs.cpp:376
UNICHAR_ID class_id() const
Definition: cluster.h:32
int IntCastRounded(double x)
Definition: helpers.h:172
static void DebugNewErrors(ShapeClassifier *new_classifier, ShapeClassifier *old_classifier, CountTypes boosting_mode, const FontInfoTable &fontinfo_table, const GenericVector< Pix * > &page_images, SampleIterator *it)
const TrainingSampleSet * sample_set() const
Definition: strngs.h:44
#define NULL
Definition: host.h:144
void print() const
Definition: statistc.cpp:538
STRING SampleToString(const TrainingSample &sample) const
int size() const
Definition: unicharset.h:297
const char * string() const
Definition: strngs.cpp:193
ICOORD & operator+=(ICOORD &op1, const ICOORD &op2)
Definition: ipoints.h:86
virtual const UNICHARSET & GetUnicharset() const