tesseract  5.0.0-alpha-619-ge9db
network.cpp
Go to the documentation of this file.
1 // File: network.cpp
3 // Description: Base class for neural network implementations.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2013, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 // Include automatically generated configuration file if running autoconf.
19 #ifdef HAVE_CONFIG_H
20 #include "config_auto.h"
21 #endif
22 
23 #include "network.h"
24 
25 #include <cstdlib>
26 
27 // This base class needs to know about all its sub-classes because of the
28 // factory deserializing method: CreateFromFile.
29 #include "allheaders.h"
30 #include "convolve.h"
31 #include "fullyconnected.h"
32 #include "input.h"
33 #include "lstm.h"
34 #include "maxpool.h"
35 #include "parallel.h"
36 #include "reconfig.h"
37 #include "reversed.h"
38 #include "scrollview.h"
39 #include "series.h"
40 #include "statistc.h"
41 #ifdef INCLUDE_TENSORFLOW
42 #include "tfnetwork.h"
43 #endif
44 #include "tprintf.h"
45 
46 namespace tesseract {
47 
// Limits, in pixels, applied to the debug windows created by ClearWindow:
// content whose shorter side is below kMinWinSize is scaled up, and neither
// window dimension may exceed kMaxWinSize.
constexpr int kMinWinSize = 500;
constexpr int kMaxWinSize = 2000;
// Extra width/height added to the content size so that the content still
// fits once the window frame is accounted for.
constexpr int kXWinFrameSize = 30;
constexpr int kYWinFrameSize = 80;
54 
55 // String names corresponding to the NetworkType enum.
56 // Keep in sync with NetworkType.
57 // Names used in Serialization to allow re-ordering/addition/deletion of
58 // layer types in NetworkType without invalidating existing network files.
59 static char const* const kTypeNames[NT_COUNT] = {
60  "Invalid", "Input",
61  "Convolve", "Maxpool",
62  "Parallel", "Replicated",
63  "ParBidiLSTM", "DepParUDLSTM",
64  "Par2dLSTM", "Series",
65  "Reconfig", "RTLReversed",
66  "TTBReversed", "XYTranspose",
67  "LSTM", "SummLSTM",
68  "Logistic", "LinLogistic",
69  "LinTanh", "Tanh",
70  "Relu", "Linear",
71  "Softmax", "SoftmaxNoCTC",
72  "LSTMSoftmax", "LSTMBinarySoftmax",
73  "TensorFlow",
74 };
75 
77  : type_(NT_NONE),
78  training_(TS_ENABLED),
79  needs_to_backprop_(true),
80  network_flags_(0),
81  ni_(0),
82  no_(0),
83  num_weights_(0),
84  forward_win_(nullptr),
85  backward_win_(nullptr),
86  randomizer_(nullptr) {}
87 Network::Network(NetworkType type, const STRING& name, int ni, int no)
88  : type_(type),
89  training_(TS_ENABLED),
90  needs_to_backprop_(true),
91  network_flags_(0),
92  ni_(ni),
93  no_(no),
94  num_weights_(0),
95  name_(name),
96  forward_win_(nullptr),
97  backward_win_(nullptr),
98  randomizer_(nullptr) {}
99 
100 
101 // Suspends/Enables/Permanently disables training by setting the training_
102 // flag. Serialize and DeSerialize only operate on the run-time data if state
103 // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
104 // temporarily disable layers in state TS_ENABLED, allowing a trainer to
105 // serialize as if it were a recognizer.
106 // TS_RE_ENABLE will re-enable layers that were previously in any disabled
107 // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
108 // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
109 // recognizer can be converted back to a trainer.
111  if (state == TS_RE_ENABLE) {
112  // Enable only from temp disabled.
114  } else if (state == TS_TEMP_DISABLE) {
115  // Temp disable only from enabled.
116  if (training_ == TS_ENABLED) training_ = state;
117  } else {
118  training_ = state;
119  }
120 }
121 
122 // Sets flags that control the action of the network. See NetworkFlags enum
123 // for bit values.
124 void Network::SetNetworkFlags(uint32_t flags) {
125  network_flags_ = flags;
126 }
127 
128 // Sets up the network for training. Initializes weights using weights of
129 // scale `range` picked according to the random number generator `randomizer`.
130 int Network::InitWeights(float range, TRand* randomizer) {
131  randomizer_ = randomizer;
132  return 0;
133 }
134 
135 // Provides a pointer to a TRand for any networks that care to use it.
136 // Note that randomizer is a borrowed pointer that should outlive the network
137 // and should not be deleted by any of the networks.
138 void Network::SetRandomizer(TRand* randomizer) {
139  randomizer_ = randomizer;
140 }
141 
142 // Sets needs_to_backprop_ to needs_backprop and returns true if
143 // needs_backprop || any weights in this network so the next layer forward
144 // can be told to produce backprop for this layer if needed.
145 bool Network::SetupNeedsBackprop(bool needs_backprop) {
146  needs_to_backprop_ = needs_backprop;
147  return needs_backprop || num_weights_ > 0;
148 }
149 
150 // Writes to the given file. Returns false in case of error.
151 bool Network::Serialize(TFile* fp) const {
152  int8_t data = NT_NONE;
153  if (!fp->Serialize(&data)) return false;
154  STRING type_name = kTypeNames[type_];
155  if (!type_name.Serialize(fp)) return false;
156  data = training_;
157  if (!fp->Serialize(&data)) return false;
158  data = needs_to_backprop_;
159  if (!fp->Serialize(&data)) return false;
160  if (!fp->Serialize(&network_flags_)) return false;
161  if (!fp->Serialize(&ni_)) return false;
162  if (!fp->Serialize(&no_)) return false;
163  if (!fp->Serialize(&num_weights_)) return false;
164  if (!name_.Serialize(fp)) return false;
165  return true;
166 }
167 
168 static NetworkType getNetworkType(TFile* fp) {
169  int8_t data;
170  if (!fp->DeSerialize(&data)) return NT_NONE;
171  if (data == NT_NONE) {
172  STRING type_name;
173  if (!type_name.DeSerialize(fp)) return NT_NONE;
174  for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
175  }
176  if (data == NT_COUNT) {
177  tprintf("Invalid network layer type:%s\n", type_name.c_str());
178  return NT_NONE;
179  }
180  }
181  return static_cast<NetworkType>(data);
182 }
183 
184 // Reads from the given file. Returns nullptr in case of error.
185 // Determines the type of the serialized class and calls its DeSerialize
186 // on a new object of the appropriate type, which is returned.
188  NetworkType type; // Type of the derived network class.
189  TrainingState training; // Are we currently training?
190  bool needs_to_backprop; // This network needs to output back_deltas.
191  int32_t network_flags; // Behavior control flags in NetworkFlags.
192  int32_t ni; // Number of input values.
193  int32_t no; // Number of output values.
194  int32_t num_weights; // Number of weights in this and sub-network.
195  STRING name; // A unique name for this layer.
196  int8_t data;
197  Network* network = nullptr;
198  type = getNetworkType(fp);
199  if (!fp->DeSerialize(&data)) return nullptr;
200  training = data == TS_ENABLED ? TS_ENABLED : TS_DISABLED;
201  if (!fp->DeSerialize(&data)) return nullptr;
202  needs_to_backprop = data != 0;
203  if (!fp->DeSerialize(&network_flags)) return nullptr;
204  if (!fp->DeSerialize(&ni)) return nullptr;
205  if (!fp->DeSerialize(&no)) return nullptr;
206  if (!fp->DeSerialize(&num_weights)) return nullptr;
207  if (!name.DeSerialize(fp)) return nullptr;
208 
209  switch (type) {
210  case NT_CONVOLVE:
211  network = new Convolve(name, ni, 0, 0);
212  break;
213  case NT_INPUT:
214  network = new Input(name, ni, no);
215  break;
216  case NT_LSTM:
217  case NT_LSTM_SOFTMAX:
219  case NT_LSTM_SUMMARY:
220  network =
221  new LSTM(name, ni, no, no, false, type);
222  break;
223  case NT_MAXPOOL:
224  network = new Maxpool(name, ni, 0, 0);
225  break;
226  // All variants of Parallel.
227  case NT_PARALLEL:
228  case NT_REPLICATED:
229  case NT_PAR_RL_LSTM:
230  case NT_PAR_UD_LSTM:
231  case NT_PAR_2D_LSTM:
232  network = new Parallel(name, type);
233  break;
234  case NT_RECONFIG:
235  network = new Reconfig(name, ni, 0, 0);
236  break;
237  // All variants of reversed.
238  case NT_XREVERSED:
239  case NT_YREVERSED:
240  case NT_XYTRANSPOSE:
241  network = new Reversed(name, type);
242  break;
243  case NT_SERIES:
244  network = new Series(name);
245  break;
246  case NT_TENSORFLOW:
247 #ifdef INCLUDE_TENSORFLOW
248  network = new TFNetwork(name);
249 #else
250  tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
251 #endif
252  break;
253  // All variants of FullyConnected.
254  case NT_SOFTMAX:
255  case NT_SOFTMAX_NO_CTC:
256  case NT_RELU:
257  case NT_TANH:
258  case NT_LINEAR:
259  case NT_LOGISTIC:
260  case NT_POSCLIP:
261  case NT_SYMCLIP:
262  network = new FullyConnected(name, ni, no, type);
263  break;
264  default:
265  break;
266  }
267  if (network) {
268  network->training_ = training;
270  network->network_flags_ = network_flags;
271  network->num_weights_ = num_weights;
272  if (!network->DeSerialize(fp)) {
273  delete network;
274  network = nullptr;
275  }
276  }
277  return network;
278 }
279 
280 // Returns a random number in [-range, range].
281 double Network::Random(double range) {
282  ASSERT_HOST(randomizer_ != nullptr);
283  return randomizer_->SignedRand(range);
284 }
285 
286 // === Debug image display methods. ===
287 // Displays the image of the matrix to the forward window.
288 void Network::DisplayForward(const NetworkIO& matrix) {
289 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
290  Pix* image = matrix.ToPix();
291  ClearWindow(false, name_.c_str(), pixGetWidth(image),
292  pixGetHeight(image), &forward_win_);
293  DisplayImage(image, forward_win_);
294  forward_win_->Update();
295 #endif // GRAPHICS_DISABLED
296 }
297 
298 // Displays the image of the matrix to the backward window.
299 void Network::DisplayBackward(const NetworkIO& matrix) {
300 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
301  Pix* image = matrix.ToPix();
302  STRING window_name = name_ + "-back";
303  ClearWindow(false, window_name.c_str(), pixGetWidth(image),
304  pixGetHeight(image), &backward_win_);
305  DisplayImage(image, backward_win_);
307 #endif // GRAPHICS_DISABLED
308 }
309 
310 #ifndef GRAPHICS_DISABLED
311 // Creates the window if needed, otherwise clears it.
312 void Network::ClearWindow(bool tess_coords, const char* window_name,
313  int width, int height, ScrollView** window) {
314  if (*window == nullptr) {
315  int min_size = std::min(width, height);
316  if (min_size < kMinWinSize) {
317  if (min_size < 1) min_size = 1;
318  width = width * kMinWinSize / min_size;
319  height = height * kMinWinSize / min_size;
320  }
321  width += kXWinFrameSize;
322  height += kYWinFrameSize;
323  if (width > kMaxWinSize) width = kMaxWinSize;
324  if (height > kMaxWinSize) height = kMaxWinSize;
325  *window = new ScrollView(window_name, 80, 100, width, height, width, height,
326  tess_coords);
327  tprintf("Created window %s of size %d, %d\n", window_name, width, height);
328  } else {
329  (*window)->Clear();
330  }
331 }
332 
333 // Displays the pix in the given window. and returns the height of the pix.
334 // The pix is pixDestroyed.
335 int Network::DisplayImage(Pix* pix, ScrollView* window) {
336  int height = pixGetHeight(pix);
337  window->Image(pix, 0, 0);
338  pixDestroy(&pix);
339  return height;
340 }
341 #endif // GRAPHICS_DISABLED
342 
343 } // namespace tesseract.
tesseract::TS_ENABLED
Definition: network.h:95
ScrollView
Definition: scrollview.h:97
tesseract::NT_PARALLEL
Definition: network.h:49
tesseract::NT_POSCLIP
Definition: network.h:63
tesseract::NT_PAR_2D_LSTM
Definition: network.h:53
tesseract::Network::SetRandomizer
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:138
tesseract::NT_XYTRANSPOSE
Definition: network.h:58
tesseract::Network::DisplayForward
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:288
tesseract::Network::SetEnableTraining
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
tesseract::NT_SOFTMAX_NO_CTC
Definition: network.h:69
tesseract::NT_PAR_RL_LSTM
Definition: network.h:51
tesseract::NT_COUNT
Definition: network.h:80
tesseract::Network::needs_to_backprop
bool needs_to_backprop() const
Definition: network.h:116
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::Parallel
Definition: parallel.h:27
tesseract::Series
Definition: series.h:27
tesseract::Network::InitWeights
virtual int InitWeights(float range, TRand *randomizer)
Definition: network.cpp:130
tesseract::Network::backward_win_
ScrollView * backward_win_
Definition: network.h:304
STRING
Definition: strngs.h:45
tesseract::Network::SetupNeedsBackprop
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: network.cpp:145
tesseract::Reconfig
Definition: reconfig.h:32
parallel.h
ScrollView::Image
void Image(struct Pix *image, int x_pos, int y_pos)
Definition: scrollview.cpp:763
tesseract::Maxpool
Definition: maxpool.h:29
network.h
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::NT_REPLICATED
Definition: network.h:50
tesseract::NetworkType
NetworkType
Definition: network.h:43
tesseract::Network::needs_to_backprop_
bool needs_to_backprop_
Definition: network.h:295
tesseract::NT_LSTM
Definition: network.h:60
tesseract::NT_SYMCLIP
Definition: network.h:64
STRING::DeSerialize
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:157
STRING::Serialize
bool Serialize(FILE *fp) const
Definition: strngs.cpp:144
tesseract::LSTM
Definition: lstm.h:28
tesseract::kMinWinSize
const int kMinWinSize
Definition: network.cpp:49
statistc.h
tesseract::Network::name_
STRING name_
Definition: network.h:300
maxpool.h
tesseract::FullyConnected
Definition: fullyconnected.h:28
tesseract::Network::CreateFromFile
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
tesseract::NT_SERIES
Definition: network.h:54
tesseract::Network::type_
NetworkType type_
Definition: network.h:293
tesseract::NetworkIO::ToPix
Pix * ToPix() const
Definition: networkio.cpp:286
STRING::c_str
const char * c_str() const
Definition: strngs.cpp:192
tesseract::NT_PAR_UD_LSTM
Definition: network.h:52
tesseract::TFile::DeSerialize
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:117
tesseract::TFile::Serialize
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:161
tesseract::NT_YREVERSED
Definition: network.h:57
tesseract::Network::forward_win_
ScrollView * forward_win_
Definition: network.h:303
tesseract::NT_TANH
Definition: network.h:65
tesseract::TFile
Definition: serialis.h:75
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::Network::randomizer_
TRand * randomizer_
Definition: network.h:305
tesseract::Network::training_
TrainingState training_
Definition: network.h:294
tesseract::Convolve
Definition: convolve.h:32
tesseract::NT_CONVOLVE
Definition: network.h:47
tesseract::TS_RE_ENABLE
Definition: network.h:99
tesseract
Definition: baseapi.h:65
tesseract::Network::SetNetworkFlags
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
tesseract::NT_INPUT
Definition: network.h:45
lstm.h
tesseract::NT_TENSORFLOW
Definition: network.h:78
tprintf.h
tesseract::TRand::SignedRand
double SignedRand(double range)
Definition: helpers.h:85
tesseract::NT_XREVERSED
Definition: network.h:56
tesseract::TS_DISABLED
Definition: network.h:94
tesseract::NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:76
tesseract::Network
Definition: network.h:105
tesseract::kXWinFrameSize
const int kXWinFrameSize
Definition: network.cpp:52
tesseract::Network::num_weights_
int32_t num_weights_
Definition: network.h:299
series.h
reconfig.h
tesseract::Network::name
const STRING & name() const
Definition: network.h:138
tesseract::Network::ClearWindow
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:312
tesseract::Reversed
Definition: reversed.h:28
fullyconnected.h
tesseract::NT_RELU
Definition: network.h:66
tfnetwork.h
tesseract::TrainingState
TrainingState
Definition: network.h:92
tesseract::kMaxWinSize
const int kMaxWinSize
Definition: network.cpp:50
tesseract::NT_NONE
Definition: network.h:44
tesseract::TS_TEMP_DISABLE
Definition: network.h:97
tesseract::NT_LSTM_SOFTMAX
Definition: network.h:75
tesseract::NT_LSTM_SUMMARY
Definition: network.h:61
tesseract::Network::DisplayImage
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:335
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesstrain_utils.type
type
Definition: tesstrain_utils.py:141
convolve.h
ScrollView::Update
static void Update()
Definition: scrollview.cpp:708
tesseract::Network::Random
double Random(double range)
Definition: network.cpp:281
tesseract::NT_LINEAR
Definition: network.h:67
tesseract::Network::Serialize
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
tesseract::NT_LOGISTIC
Definition: network.h:62
reversed.h
tesseract::kYWinFrameSize
const int kYWinFrameSize
Definition: network.cpp:53
tesseract::Network::no_
int32_t no_
Definition: network.h:298
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
tesseract::TRand
Definition: helpers.h:50
tesseract::Network::num_weights
int num_weights() const
Definition: network.h:119
tesseract::Input
Definition: input.h:27
tesseract::NT_MAXPOOL
Definition: network.h:48
scrollview.h
input.h
tesseract::Network::DisplayBackward
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:299
tesseract::Network::network_flags_
int32_t network_flags_
Definition: network.h:296
tesseract::NT_RECONFIG
Definition: network.h:55
tesseract::NT_SOFTMAX
Definition: network.h:68
tesseract::Network::DeSerialize
virtual bool DeSerialize(TFile *fp)=0
tesseract::Network::Network
Network()
Definition: network.cpp:76