All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Modules Pages
mastertrainer.h
Go to the documentation of this file.
1 // Copyright 2010 Google Inc. All Rights Reserved.
2 // Author: rays@google.com (Ray Smith)
4 // File: mastertrainer.h
5 // Description: Trainer to build the MasterClassifier.
6 // Author: Ray Smith
7 // Created: Wed Nov 03 18:07:01 PDT 2010
8 //
9 // (C) Copyright 2010, Google Inc.
10 // Licensed under the Apache License, Version 2.0 (the "License");
11 // you may not use this file except in compliance with the License.
12 // You may obtain a copy of the License at
13 // http://www.apache.org/licenses/LICENSE-2.0
14 // Unless required by applicable law or agreed to in writing, software
15 // distributed under the License is distributed on an "AS IS" BASIS,
16 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 // See the License for the specific language governing permissions and
18 // limitations under the License.
19 //
21 
22 #ifndef TESSERACT_TRAINING_MASTERTRAINER_H__
23 #define TESSERACT_TRAINING_MASTERTRAINER_H__
24 
28 #include "classify.h"
29 #include "cluster.h"
30 #include "intfx.h"
31 #include "elst.h"
32 #include "errorcounter.h"
33 #include "featdefs.h"
34 #include "fontinfo.h"
35 #include "indexmapbidi.h"
36 #include "intfeaturespace.h"
37 #include "intfeaturemap.h"
38 #include "intmatcher.h"
39 #include "params.h"
40 #include "shapetable.h"
41 #include "trainingsample.h"
42 #include "trainingsampleset.h"
43 #include "unicharset.h"
44 
45 namespace tesseract {
46 
47 class ShapeClassifier;
48 
49 // Simple struct to hold the distance between two shapes during clustering: a pair of shape ids plus the distance between them.
50 struct ShapeDist {
51  ShapeDist() : shape1(0), shape2(0), distance(0.0f) {}  // Default: both ids 0, distance 0.
52  ShapeDist(int s1, int s2, float dist)
53  : shape1(s1), shape2(s2), distance(dist) {}
54 
55  // Sort operator to sort in ascending order of distance. Only distance is compared; shape1 and shape2 are ignored by the ordering.
56  bool operator<(const ShapeDist& other) const {
57  return distance < other.distance;
58  }
59 
60  int shape1;  // First shape of the pair (presumably an index into a ShapeTable -- confirm against the clustering code).
61  int shape2;  // Second shape of the pair (same caveat as shape1).
62  float distance;  // Distance between shape1 and shape2; the sole sort key.
63 };
64 
65 // Class to encapsulate training processes that use the TrainingSampleSet.
66 // Initially supports shape clustering and mftraining.
67 // Other important features of the MasterTrainer are conditioning the data
68 // by outlier elimination, replication with perturbation, and serialization.
69 class MasterTrainer {
70  public:
71  // Constructs a trainer. The arguments select the normalization mode, whether
71  // shape analysis and sample replication are enabled, and the debug output level.
71  MasterTrainer(NormalizationMode norm_mode, bool shape_analysis,
72  bool replicate_samples, int debug_level);
74 
75  // Writes to the given file. Returns false in case of error.
76  bool Serialize(FILE* fp) const;
77  // Reads from the given file. Returns false in case of error.
78  // If swap is true, assumes a big/little-endian swap is needed.
79  bool DeSerialize(bool swap, FILE* fp);
80 
81  // Loads an initial unicharset, or sets one up if the file cannot be read.
82  void LoadUnicharset(const char* filename);
83 
84  // Sets the feature space definition: copies fs into feature_space_ and initializes feature_map_ from the same fs.
85  void SetFeatureSpace(const IntFeatureSpace& fs) {
86  feature_space_ = fs;
87  feature_map_.Init(fs);  // Keeps the feature map built on the space just stored.
88  }
89 
90  // Reads the samples and their features from the given file (page_name),
91  // adding them to the trainer with the font_id from the content of the file.
92  // If verification, then these are verification samples, not training.
93  void ReadTrainingSamples(const char* page_name,
94  const FEATURE_DEFS_STRUCT& feature_defs,
95  bool verification);
96 
97  // Adds the given single sample to the trainer, setting the classid
98  // appropriately from the given unichar_str.
99  void AddSample(bool verification, const char* unichar_str,
100  TrainingSample* sample);
101 
102  // Loads all pages from the given tif filename and appends them to page_images_.
103  // Must be called after ReadTrainingSamples, as the current number of images
104  // is used as an offset for page numbers in the samples.
105  void LoadPageImages(const char* filename);
106 
107  // Cleans up the samples after initial load from the tr files, and prior to
108  // saving the MasterTrainer:
109  // Remaps fragmented chars if running shape analysis.
110  // Sets up the samples appropriately for class/fontwise access.
111  // Deletes outlier samples.
112  void PostLoadCleanup();
113 
114  // Gets the samples ready for training. Use after either
115  // ReadTrainingSamples+PostLoadCleanup or DeSerialize.
116  // Re-indexes the features and computes canonical and cloud features.
117  void PreTrainingSetup();
118 
119  // Sets up the master_shapes_ table, which tells which fonts should stay
120  // together until they get to a leaf node classifier.
121  void SetupMasterShapes();
122 
123  // Adds the junk_samples_ to the main samples_ set. Junk samples are initially
124  // fragments and n-grams (all incorrectly segmented characters).
125  // Various training functions may result in incorrectly segmented characters
126  // being added to the unicharset of the main samples, perhaps because they
127  // form a "radical" decomposition of some (Indic) grapheme, or because they
128  // just look the same as a real character (like rn/m)
129  // This function moves all the junk samples, to the main samples_ set, but
130  // desirable junk, being any sample for which the unichar already exists in
131  // the samples_ unicharset gets the unichar-ids re-indexed to match, but
132  // anything else gets re-marked as unichar_id 0 (space character) to identify
133  // it as junk to the error counter.
134  void IncludeJunk();
135 
136  // Replicates the samples and perturbs them if the enable_replication_ flag
137  // is set. MUST be used after the last call to OrganizeByFontAndClass on
138  // the training samples, ie after IncludeJunk if it is going to be used, as
139  // OrganizeByFontAndClass will eat the replicated samples into the regular
140  // samples.
141  void ReplicateAndRandomizeSamplesIfRequired();
142 
143  // Loads the basic font properties file into fontinfo_table_.
144  // Returns false on failure.
145  bool LoadFontInfo(const char* filename);
146 
147  // Loads the xheight font properties file into xheights_.
148  // Returns false on failure.
149  bool LoadXHeights(const char* filename);
150 
151  // Reads spacing stats from filename and adds them to fontinfo_table.
152  // Returns false on failure.
153  bool AddSpacingInfo(const char *filename);
154 
155  // Returns the font id corresponding to the given font name.
156  // Returns -1 if the font cannot be found.
157  int GetFontInfoId(const char* font_name);
158  // Returns the font_id of the closest matching font name to the given
159  // filename. It is assumed that a substring of the filename will match
160  // one of the fonts. If more than one is matched, the longest is returned.
161  int GetBestMatchingFontInfoId(const char* filename);
162 
163  // Returns the filename of the tr file corresponding to the command-line
164  // argument with the given index. The result is a reference into tr_filenames_.
165  const STRING& GetTRFileName(int index) const {
166  return tr_filenames_[index];  // index is used as-is; this accessor adds no range check of its own.
167  }
168 
169  // Sets up a flat shapetable with one shape per class/font combination.
170  void SetupFlatShapeTable(ShapeTable* shape_table);
171 
172  // Sets up a Clusterer for mftraining on a single shape_id.
173  // Call FreeClusterer on the return value after use.
174  CLUSTERER* SetupForClustering(const ShapeTable& shape_table,
175  const FEATURE_DEFS_STRUCT& feature_defs,
176  int shape_id, int* num_samples);
177 
178  // Writes the given float_classes (produced by SetupForFloat2Int) as inttemp
179  // to the given inttemp_file, and the corresponding pffmtable.
180  // The unicharset is the original encoding of graphemes, and shape_set should
181  // match the size of the shape_table, and may possibly be totally fake.
182  void WriteInttempAndPFFMTable(const UNICHARSET& unicharset,
183  const UNICHARSET& shape_set,
184  const ShapeTable& shape_table,
185  CLASS_STRUCT* float_classes,
186  const char* inttemp_file,
187  const char* pffmtable_file);
188 
189  const UNICHARSET& unicharset() const {  // Read-only access to the unicharset held by samples_ (note: not the unicharset_ member).
190  return samples_.unicharset();
191  }
192  TrainingSampleSet* GetSamples() {  // Mutable access to the main training samples; the pointer refers to the samples_ member.
193  return &samples_;
194  }
195  const ShapeTable& master_shapes() const {  // Read-only access to the master_shapes_ table.
196  return master_shapes_;
197  }
198 
199  // Generates debug output relating to the canonical distance between the
200  // two given UTF8 grapheme strings.
201  void DebugCanonical(const char* unichar_str1, const char* unichar_str2);
202  #ifndef GRAPHICS_DISABLED
203  // Debugging for cloud/canonical features.
204  // Displays a Features window containing:
205  // If unichar_str2 is in the unicharset, and canonical_font is non-negative,
206  // displays the canonical features of the char/font combination in red.
207  // If unichar_str1 is in the unicharset, and cloud_font is non-negative,
208  // displays the cloud feature of the char/font combination in green.
209  // The canonical features are drawn first to show which ones have no
210  // matches in the cloud features.
211  // Until the features window is destroyed, each click in the features window
212  // will display the samples that have that feature in a separate window.
213  void DisplaySamples(const char* unichar_str1, int cloud_font,
214  const char* unichar_str2, int canonical_font);
215  #endif // GRAPHICS_DISABLED
216 
217  void TestClassifierVOld(bool replicate_samples,
218  ShapeClassifier* test_classifier,
219  ShapeClassifier* old_classifier);
220 
221  // Tests the given test_classifier on the internal samples.
222  // See TestClassifier for details.
223  void TestClassifierOnSamples(CountTypes error_mode,
224  int report_level,
225  bool replicate_samples,
226  ShapeClassifier* test_classifier,
227  STRING* report_string);
228  // Tests the given test_classifier on the given samples
229  // error_mode indicates what counts as an error.
230  // report_levels:
231  // 0 = no output.
232  // 1 = bottom-line error rate.
233  // 2 = bottom-line error rate + time.
234  // 3 = font-level error rate + time.
235  // 4 = list of all errors + short classifier debug output on 16 errors.
236  // 5 = list of all errors + short classifier debug output on 25 errors.
237  // If replicate_samples is true, then the test is run on an extended test
238  // sample including replicated and systematically perturbed samples.
239  // If report_string is non-NULL, a summary of the results for each font
240  // is appended to the report_string.
241  double TestClassifier(CountTypes error_mode,
242  int report_level,
243  bool replicate_samples,
244  TrainingSampleSet* samples,
245  ShapeClassifier* test_classifier,
246  STRING* report_string);
247 
248  // Returns the average (in some sense) distance between the two given
249  // shapes, which may contain multiple fonts and/or unichars.
250  // This function is public to facilitate testing.
251  float ShapeDistance(const ShapeTable& shapes, int s1, int s2);
252 
253  private:
254  // Replaces samples that are always fragmented with the corresponding
255  // fragment samples.
256  void ReplaceFragmentedSamples();
257 
258  // Runs a hierarchical agglomerative clustering to merge shapes in the given
259  // shape_table, while satisfying the given constraints:
260  // * End with at least min_shapes left in shape_table,
261  // * No shape shall have more than max_shape_unichars in it,
262  // * Don't merge shapes where the distance between them exceeds max_dist.
263  void ClusterShapes(int min_shapes, int max_shape_unichars,
264  float max_dist, ShapeTable* shape_table);
265 
266  private:
267  NormalizationMode norm_mode_;
268  // Character set we are training for.
269  UNICHARSET unicharset_;
270  // Original feature space. Subspace mapping is contained in feature_map_.
271  IntFeatureSpace feature_space_;
272  TrainingSampleSet samples_;
273  TrainingSampleSet junk_samples_;
274  TrainingSampleSet verify_samples_;
275  // Master shape table defines what fonts stay together until the leaves.
276  ShapeTable master_shapes_;
277  // Flat shape table has each unichar/font id pair in a separate shape.
278  ShapeTable flat_shapes_;
279  // Font metrics gathered from multiple files.
280  FontInfoTable fontinfo_table_;
281  // Array of xheights indexed by font ids in fontinfo_table_;
282  GenericVector<inT32> xheights_;
283 
284  // Non-serialized data initialized by other means or used temporarily
285  // during loading of training samples.
286  // Number of different class labels in unicharset_.
287  int charsetsize_;
288  // Flag to indicate that we are running shape analysis and need fragments
289  // fixing.
290  bool enable_shape_anaylsis_;
291  // Flag to indicate that sample replication is required.
292  bool enable_replication_;
293  // Array of classids of fragments that replace the correctly segmented chars.
294  int* fragments_;
295  // Classid of previous correctly segmented sample that was added.
296  int prev_unichar_id_;
297  // Debug output control.
298  int debug_level_;
299  // Feature map used to construct reduced feature spaces for compact
300  // classifiers.
301  IntFeatureMap feature_map_;
302  // Vector of Pix pointers used for classifiers that need the image.
303  // Indexed by page_num_ in the samples.
304  // These images are owned by the trainer and need to be pixDestroyed.
305  GenericVector<Pix*> page_images_;
306  // Vector of filenames of loaded tr files.
307  GenericVector<STRING> tr_filenames_;
308 };
309 
310 } // namespace tesseract.
311 
312 #endif
bool LoadFontInfo(const char *filename)
void LoadPageImages(const char *filename)
void TestClassifierVOld(bool replicate_samples, ShapeClassifier *test_classifier, ShapeClassifier *old_classifier)
void SetupFlatShapeTable(ShapeTable *shape_table)
void AddSample(bool verification, const char *unichar_str, TrainingSample *sample)
bool Serialize(FILE *fp) const
bool operator<(const ShapeDist &other) const
Definition: mastertrainer.h:56
void DisplaySamples(const char *unichar_str1, int cloud_font, const char *unichar_str2, int canonical_font)
int GetFontInfoId(const char *font_name)
double TestClassifier(CountTypes error_mode, int report_level, bool replicate_samples, TrainingSampleSet *samples, ShapeClassifier *test_classifier, STRING *report_string)
FEATURE_DEFS_STRUCT feature_defs
void Init(const IntFeatureSpace &feature_space)
const STRING & GetTRFileName(int index) const
bool LoadXHeights(const char *filename)
void ReplicateAndRandomizeSamplesIfRequired()
const UNICHARSET & unicharset() const
int GetBestMatchingFontInfoId(const char *filename)
void LoadUnicharset(const char *filename)
void TestClassifierOnSamples(CountTypes error_mode, int report_level, bool replicate_samples, ShapeClassifier *test_classifier, STRING *report_string)
float ShapeDistance(const ShapeTable &shapes, int s1, int s2)
const UNICHARSET & unicharset() const
bool DeSerialize(bool swap, FILE *fp)
MasterTrainer(NormalizationMode norm_mode, bool shape_analysis, bool replicate_samples, int debug_level)
const ShapeTable & master_shapes() const
Definition: cluster.h:32
void WriteInttempAndPFFMTable(const UNICHARSET &unicharset, const UNICHARSET &shape_set, const ShapeTable &shape_table, CLASS_STRUCT *float_classes, const char *inttemp_file, const char *pffmtable_file)
Definition: strngs.h:44
void ReadTrainingSamples(const char *page_name, const FEATURE_DEFS_STRUCT &feature_defs, bool verification)
ShapeDist(int s1, int s2, float dist)
Definition: mastertrainer.h:52
NormalizationMode
Definition: normalis.h:44
void SetFeatureSpace(const IntFeatureSpace &fs)
Definition: mastertrainer.h:85
void DebugCanonical(const char *unichar_str1, const char *unichar_str2)
CLUSTERER * SetupForClustering(const ShapeTable &shape_table, const FEATURE_DEFS_STRUCT &feature_defs, int shape_id, int *num_samples)
TrainingSampleSet * GetSamples()
bool AddSpacingInfo(const char *filename)