tesseract  5.0.0-alpha-619-ge9db
lstmrecognizer.h
Go to the documentation of this file.
1 // File: lstmrecognizer.h
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
19 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
20 
21 #include "ccutil.h"
22 #include <tesseract/helpers.h>
23 #include "imagedata.h"
24 #include "matrix.h"
25 #include "network.h"
26 #include "networkscratch.h"
27 #include "params.h"
28 #include "recodebeam.h"
29 #include "series.h"
30 #include <tesseract/strngs.h>
31 #include "unicharcompress.h"
32 
33 class BLOB_CHOICE_IT;
34 struct Pix;
35 class ROW_RES;
36 class ScrollView;
37 class TBOX;
38 class WERD_RES;
39 
40 namespace tesseract {
41 
42 class Dict;
43 class ImageData;
44 
45 // Enum indicating training mode control flags.
49 };
50 
51 // Top-level line recognizer class for LSTM-based networks.
52 // Note that a sub-class, LSTMTrainer is used for training.
54  public:
56  LSTMRecognizer(const STRING language_data_path_prefix);
58 
59  int NumOutputs() const { return network_->NumOutputs(); }
60  int training_iteration() const { return training_iteration_; }
61  int sample_iteration() const { return sample_iteration_; }
62  double learning_rate() const { return learning_rate_; }
64  if (network_ == nullptr) return LT_NONE;
65  StaticShape shape;
66  shape = network_->OutputShape(shape);
67  return shape.loss_type();
68  }
69  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
70  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
71  // True if recoder_ is active to re-encode text to a smaller space.
72  bool IsRecoding() const {
74  }
75  // Returns true if the network is a TensorFlow network.
76  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
77  // Returns a vector of layer ids that can be passed to other layer functions
78  // to access a specific layer.
80  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
81  auto* series = static_cast<Series*>(network_);
82  GenericVector<STRING> layers;
83  series->EnumerateLayers(nullptr, &layers);
84  return layers;
85  }
86  // Returns a specific layer from its id (from EnumerateLayers).
87  Network* GetLayer(const STRING& id) const {
88  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
89  ASSERT_HOST(id.length() > 1 && id[0] == ':');
90  auto* series = static_cast<Series*>(network_);
91  return series->GetLayer(&id[1]);
92  }
93  // Returns the learning rate of the layer from its id.
94  float GetLayerLearningRate(const STRING& id) const {
95  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
97  ASSERT_HOST(id.length() > 1 && id[0] == ':');
98  auto* series = static_cast<Series*>(network_);
99  return series->LayerLearningRate(&id[1]);
100  } else {
101  return learning_rate_;
102  }
103  }
104  // Multiplies all the learning rate(s) by the given factor.
105  void ScaleLearningRate(double factor) {
106  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
107  learning_rate_ *= factor;
110  for (int i = 0; i < layers.size(); ++i) {
111  ScaleLayerLearningRate(layers[i], factor);
112  }
113  }
114  }
115  // Multiplies the learning rate of the layer with id, by the given factor.
116  void ScaleLayerLearningRate(const STRING& id, double factor) {
117  ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
118  ASSERT_HOST(id.length() > 1 && id[0] == ':');
119  auto* series = static_cast<Series*>(network_);
120  series->ScaleLayerLearningRate(&id[1], factor);
121  }
122 
123  // Converts the network to int if not already.
124  void ConvertToInt() {
125  if ((training_flags_ & TF_INT_MODE) == 0) {
128  }
129  }
130 
131  // Provides access to the UNICHARSET that this classifier works with.
132  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
134  // Provides access to the UnicharCompress that this classifier works with.
135  const UnicharCompress& GetRecoder() const { return recoder_; }
136  // Provides access to the Dict that this classifier works with.
137  const Dict* GetDict() const { return dict_; }
138  Dict* GetDict() { return dict_; }
139  // Sets the sample iteration to the given value. The sample_iteration_
140  // determines the seed for the random number generator. The training
141  // iteration is incremented only by a successful training iteration.
142  void SetIteration(int iteration) { sample_iteration_ = iteration; }
143  // Accessors for textline image normalization.
144  int NumInputs() const { return network_->NumInputs(); }
145  int null_char() const { return null_char_; }
146 
147  // Loads a model from mgr, including the dictionary only if lang is not null.
148  bool Load(const ParamsVectors* params, const char* lang,
149  TessdataManager* mgr);
150 
151  // Writes to the given file. Returns false in case of error.
152  // If mgr contains a unicharset and recoder, then they are not encoded to fp.
153  bool Serialize(const TessdataManager* mgr, TFile* fp) const;
154  // Reads from the given file. Returns false in case of error.
155  // If mgr contains a unicharset and recoder, then they are taken from there,
156  // otherwise, they are part of the serialization in fp.
157  bool DeSerialize(const TessdataManager* mgr, TFile* fp);
158  // Loads the charsets from mgr.
159  bool LoadCharsets(const TessdataManager* mgr);
160  // Loads the Recoder.
161  bool LoadRecoder(TFile* fp);
162  // Loads the dictionary if possible from the traineddata file.
163  // If loading fails, prints a warning message and returns false, but
164  // otherwise carries on, continuing to work without the dictionary.
165  // Note that dictionary load is independent from DeSerialize, but dependent
166  // on the unicharset matching. This enables training to deserialize a model
167  // from checkpoint or restore without having to go back and reload the
168  // dictionary.
169  bool LoadDictionary(const ParamsVectors* params, const char* lang,
170  TessdataManager* mgr);
171 
172  // Recognizes the line image, contained within image_data, returning the
173  // recognized tesseract WERD_RES for the words.
174  // If invert, tries inverted as well if the normal interpretation doesn't
175  // produce a good enough result. The line_box is used for computing the
176  // box_word in the output words. worst_dict_cert is the worst certainty that
177  // will be used in a dictionary word.
178  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
179  double worst_dict_cert, const TBOX& line_box,
180  PointerVector<WERD_RES>* words, int lstm_choice_mode = 0,
181  int lstm_choice_amount = 5);
182 
183  // Helper computes min and mean best results in the output.
184  void OutputStats(const NetworkIO& outputs, float* min_output,
185  float* mean_output, float* sd);
186  // Recognizes the image_data, storing the network's forward outputs in
187  // outputs.
188  // Returned in scale_factor is the reduction factor
189  // between the image and the output coords, for computing bounding boxes.
190  // If re_invert is true, the input is inverted back to its original
191  // photometric interpretation if inversion is attempted but fails to
192  // improve the results. This ensures that outputs contains the correct
193  // forward outputs for the best photometric interpretation.
194  // inputs is filled with the used inputs to the network.
195  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
196  bool re_invert, bool upside_down, float* scale_factor,
197  NetworkIO* inputs, NetworkIO* outputs);
198 
199  // Converts an array of labels to utf-8, whether or not the labels are
200  // augmented with character boundaries.
201  STRING DecodeLabels(const GenericVector<int>& labels);
202 
203  // Displays the forward results in a window with the characters and
204  // boundaries as determined by the labels and label_coords.
205  void DisplayForward(const NetworkIO& inputs, const GenericVector<int>& labels,
206  const GenericVector<int>& label_coords,
207  const char* window_name, ScrollView** window);
208  // Converts the network output to a sequence of labels. Outputs labels, scores
209  // and start xcoords of each char, and each null_char_, with an additional
210  // final xcoord for the end of the output.
211  // The conversion method is determined by internal state.
212  void LabelsFromOutputs(const NetworkIO& outputs, GenericVector<int>* labels,
213  GenericVector<int>* xcoords);
214 
215  protected:
216  // Sets the random seed from the sample_iteration_.
217  void SetRandomSeed() {
218  int64_t seed = static_cast<int64_t>(sample_iteration_) * 0x10000001;
219  randomizer_.set_seed(seed);
221  }
222 
223  // Displays the labels and cuts at the corresponding xcoords.
224  // Size of labels should match xcoords.
225  void DisplayLSTMOutput(const GenericVector<int>& labels,
226  const GenericVector<int>& xcoords, int height,
227  ScrollView* window);
228 
229  // Prints debug output detailing the activation path that is implied by the
230  // xcoords.
231  void DebugActivationPath(const NetworkIO& outputs,
232  const GenericVector<int>& labels,
233  const GenericVector<int>& xcoords);
234 
235  // Prints debug output detailing activations and 2nd choice over a range
236  // of positions.
237  void DebugActivationRange(const NetworkIO& outputs, const char* label,
238  int best_choice, int x_start, int x_end);
239 
240  // As LabelsViaCTC except that this function constructs the best path that
241  // contains only legal sequences of subcodes for recoder_.
242  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
243  GenericVector<int>* xcoords);
244  // Converts the network output to a sequence of labels, with scores, using
245  // the simple character model (each position is a char, and the null_char_ is
246  // mainly intended for tail padding.)
247  void LabelsViaSimpleText(const NetworkIO& output, GenericVector<int>* labels,
248  GenericVector<int>* xcoords);
249 
250  // Returns a string corresponding to the label starting at start. Sets *end
251  // to the next start and if non-null, *decoded to the unichar id.
252  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
253  int* decoded);
254 
255  // Returns a string corresponding to a given single label id, falling back to
256  // a default of ".." for part of a multi-label unichar-id.
257  const char* DecodeSingleLabel(int label);
258 
259  protected:
260  // The network hierarchy.
262  // The unicharset. Only the unicharset element is serialized.
263  // Has to be a CCUtil, so Dict can point to it.
265  // For backward compatibility, recoder_ is serialized iff
266  // training_flags_ & TF_COMPRESS_UNICHARSET.
267  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
269 
270  // ==Training parameters that are serialized to provide a record of them.==
272  // Flags used to determine the training method of the network.
273  // See enum TrainingFlags above.
275  // Number of actual backward training steps used.
277  // Index into training sample set. sample_iteration >= training_iteration_.
279  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
280  // ccutil_.unicharset.size().
281  int32_t null_char_;
282  // Learning rate and momentum multipliers of deltas in backprop.
284  float momentum_;
285  // Smoothing factor for 2nd moment of gradients.
286  float adam_beta_;
287 
288  // === NOT SERIALIZED.
291  // Language model (optional) to use with the beam search.
293  // Beam search held between uses to optimize memory allocation/use.
295 
296  // == Debugging parameters.==
297  // Recognition debug display window.
299 };
300 
301 } // namespace tesseract.
302 
303 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
tesseract::StaticShape
Definition: static_shape.h:38
tesseract::LSTMRecognizer::learning_rate_
float learning_rate_
Definition: lstmrecognizer.h:283
ScrollView
Definition: scrollview.h:97
strngs.h
tesseract::LSTMRecognizer::dict_
Dict * dict_
Definition: lstmrecognizer.h:292
tesseract::StaticShape::loss_type
LossType loss_type() const
Definition: static_shape.h:50
tesseract::LSTMRecognizer::DebugActivationPath
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
Definition: lstmrecognizer.cpp:392
tesseract::LSTMRecognizer::IsTensorFlow
bool IsTensorFlow() const
Definition: lstmrecognizer.h:76
tesseract::LSTMRecognizer::LabelsViaSimpleText
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:488
tesseract::RecodeBeamSearch
Definition: recodebeam.h:180
tesseract::TessdataManager
Definition: tessdatamanager.h:126
tesseract::LSTMRecognizer::ScaleLayerLearningRate
void ScaleLayerLearningRate(const STRING &id, double factor)
Definition: lstmrecognizer.h:116
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::LSTMRecognizer::IsIntMode
bool IsIntMode() const
Definition: lstmrecognizer.h:70
tesseract::LSTMRecognizer::search_
RecodeBeamSearch * search_
Definition: lstmrecognizer.h:294
params.h
tesseract::LSTMRecognizer::training_iteration
int training_iteration() const
Definition: lstmrecognizer.h:60
tesseract::LSTMRecognizer::LoadDictionary
bool LoadDictionary(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
Definition: lstmrecognizer.cpp:167
tesseract::TrainingFlags
TrainingFlags
Definition: lstmrecognizer.h:46
tesseract::LSTMRecognizer::randomizer_
TRand randomizer_
Definition: lstmrecognizer.h:289
tesseract::PointerVector< WERD_RES >
tesseract::LSTMRecognizer::learning_rate
double learning_rate() const
Definition: lstmrecognizer.h:62
tesseract::TRand::IntRand
int32_t IntRand()
Definition: helpers.h:80
tesseract::LSTMRecognizer::GetDict
Dict * GetDict()
Definition: lstmrecognizer.h:138
STRING
Definition: strngs.h:45
recodebeam.h
WERD_RES
Definition: pageres.h:160
tesseract::NetworkScratch
Definition: networkscratch.h:34
tesseract::LSTMRecognizer::sample_iteration_
int32_t sample_iteration_
Definition: lstmrecognizer.h:278
network.h
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::LSTMRecognizer::network_str_
STRING network_str_
Definition: lstmrecognizer.h:271
tesseract::LSTMRecognizer::EnumerateLayers
GenericVector< STRING > EnumerateLayers() const
Definition: lstmrecognizer.h:79
tesseract::LSTMRecognizer::DebugActivationRange
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
Definition: lstmrecognizer.cpp:419
tesseract::LSTMRecognizer::LabelsFromOutputs
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:462
tesseract::LSTMRecognizer::DisplayForward
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
Definition: lstmrecognizer.cpp:349
tesseract::LSTMRecognizer::GetUnicharset
const UNICHARSET & GetUnicharset() const
Definition: lstmrecognizer.h:132
tesseract::ImageData
Definition: imagedata.h:104
tesseract::CCUtil::unicharset
UNICHARSET unicharset
Definition: ccutil.h:57
tesseract::LSTMRecognizer::LSTMRecognizer
LSTMRecognizer()
Definition: lstmrecognizer.cpp:57
tesseract::LSTMRecognizer::LoadRecoder
bool LoadRecoder(TFile *fp)
Definition: lstmrecognizer.cpp:143
tesseract::Network::TestFlag
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
tesseract::LSTMRecognizer::ccutil_
CCUtil ccutil_
Definition: lstmrecognizer.h:264
tesseract::LSTMRecognizer::GetLayerLearningRate
float GetLayerLearningRate(const STRING &id) const
Definition: lstmrecognizer.h:94
tesseract::LT_NONE
Definition: static_shape.h:30
tesseract::LSTMRecognizer::LoadCharsets
bool LoadCharsets(const TessdataManager *mgr)
Definition: lstmrecognizer.cpp:133
tesseract::Network::OutputShape
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
tesseract::LSTMRecognizer::SetIteration
void SetIteration(int iteration)
Definition: lstmrecognizer.h:142
tesseract::LSTMRecognizer::DeSerialize
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
Definition: lstmrecognizer.cpp:108
networkscratch.h
tesseract::LSTMRecognizer::RecognizeLine
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
Definition: lstmrecognizer.cpp:187
tesseract::LSTMRecognizer::scratch_space_
NetworkScratch scratch_space_
Definition: lstmrecognizer.h:290
tesseract::LSTMRecognizer::GetRecoder
const UnicharCompress & GetRecoder() const
Definition: lstmrecognizer.h:135
tesseract::LossType
LossType
Definition: static_shape.h:29
tesseract::LSTMRecognizer::training_flags_
int32_t training_flags_
Definition: lstmrecognizer.h:274
tesseract::LSTMRecognizer::null_char
int null_char() const
Definition: lstmrecognizer.h:145
tesseract::NT_SERIES
Definition: network.h:54
tesseract::LSTMRecognizer::SetRandomSeed
void SetRandomSeed()
Definition: lstmrecognizer.h:217
tesseract::LSTMRecognizer::momentum_
float momentum_
Definition: lstmrecognizer.h:284
tesseract::LSTMRecognizer::debug_win_
ScrollView * debug_win_
Definition: lstmrecognizer.h:298
tesseract::Network::ConvertToInt
virtual void ConvertToInt()
Definition: network.h:191
tesseract::LSTMRecognizer::ConvertToInt
void ConvertToInt()
Definition: lstmrecognizer.h:124
ccutil.h
tesseract::LSTMRecognizer::OutputLossType
LossType OutputLossType() const
Definition: lstmrecognizer.h:63
tesseract::LSTMRecognizer::ScaleLearningRate
void ScaleLearningRate(double factor)
Definition: lstmrecognizer.h:105
matrix.h
tesseract::TFile
Definition: serialis.h:75
UNICHARSET
Definition: unicharset.h:145
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::ParamsVectors
Definition: params.h:56
tesseract::LSTMRecognizer::GetLayer
Network * GetLayer(const STRING &id) const
Definition: lstmrecognizer.h:87
tesseract::LSTMRecognizer::Serialize
bool Serialize(const TessdataManager *mgr, TFile *fp) const
Definition: lstmrecognizer.cpp:89
tesseract::LSTMRecognizer::recoder_
UnicharCompress recoder_
Definition: lstmrecognizer.h:268
tesseract::LSTMRecognizer::training_iteration_
int32_t training_iteration_
Definition: lstmrecognizer.h:276
helpers.h
tesseract
Definition: baseapi.h:65
tesseract::TF_INT_MODE
Definition: lstmrecognizer.h:47
tesseract::LSTMRecognizer::OutputStats
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
Definition: lstmrecognizer.cpp:238
tesseract::LSTMRecognizer::NumInputs
int NumInputs() const
Definition: lstmrecognizer.h:144
tesseract::LSTMRecognizer::SimpleTextOutput
bool SimpleTextOutput() const
Definition: lstmrecognizer.h:69
tesseract::LSTMRecognizer::null_char_
int32_t null_char_
Definition: lstmrecognizer.h:281
tesseract::NT_TENSORFLOW
Definition: network.h:78
tesseract::Network::NumOutputs
int NumOutputs() const
Definition: network.h:123
GenericVector< STRING >
tesseract::Dict
Definition: dict.h:91
tesseract::LT_SOFTMAX
Definition: static_shape.h:32
tesseract::TRand::set_seed
void set_seed(uint64_t seed)
Definition: helpers.h:70
tesseract::Network
Definition: network.h:105
series.h
tesseract::LSTMRecognizer::LabelsViaReEncode
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
Definition: lstmrecognizer.cpp:474
tesseract::LSTMRecognizer::~LSTMRecognizer
~LSTMRecognizer()
Definition: lstmrecognizer.cpp:70
imagedata.h
ROW_RES
Definition: pageres.h:133
tesseract::LSTMRecognizer::DecodeSingleLabel
const char * DecodeSingleLabel(int label)
Definition: lstmrecognizer.cpp:549
tesseract::LSTMRecognizer
Definition: lstmrecognizer.h:53
tesseract::LSTMRecognizer::IsRecoding
bool IsRecoding() const
Definition: lstmrecognizer.h:72
unicharcompress.h
tesseract::LSTMRecognizer::NumOutputs
int NumOutputs() const
Definition: lstmrecognizer.h:59
tesseract::LSTMRecognizer::GetUnicharset
UNICHARSET & GetUnicharset()
Definition: lstmrecognizer.h:133
tesseract::LSTMRecognizer::network_
Network * network_
Definition: lstmrecognizer.h:261
tesseract::LSTMRecognizer::GetDict
const Dict * GetDict() const
Definition: lstmrecognizer.h:137
tesseract::LSTMRecognizer::DisplayLSTMOutput
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
Definition: lstmrecognizer.cpp:365
tesseract::LSTMRecognizer::DecodeLabel
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
Definition: lstmrecognizer.cpp:507
tesseract::NF_LAYER_SPECIFIC_LR
Definition: network.h:87
tesseract::TF_COMPRESS_UNICHARSET
Definition: lstmrecognizer.h:48
tesseract::Network::NumInputs
int NumInputs() const
Definition: network.h:120
tesseract::LSTMRecognizer::Load
bool Load(const ParamsVectors *params, const char *lang, TessdataManager *mgr)
Definition: lstmrecognizer.cpp:77
tesseract::UnicharCompress
Definition: unicharcompress.h:128
tesseract::TRand
Definition: helpers.h:50
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::LSTMRecognizer::adam_beta_
float adam_beta_
Definition: lstmrecognizer.h:286
tesseract::CCUtil
Definition: ccutil.h:40
tesseract::LSTMRecognizer::sample_iteration
int sample_iteration() const
Definition: lstmrecognizer.h:61
TBOX
Definition: rect.h:33
tesseract::LSTMRecognizer::DecodeLabels
STRING DecodeLabels(const GenericVector< int > &labels)
Definition: lstmrecognizer.cpp:334