tesseract  5.0.0-alpha-619-ge9db
networkbuilder.cpp
Go to the documentation of this file.
1 // File: networkbuilder.cpp
3 // Description: Class to parse the network description language and
4 // build a corresponding network.
5 // Author: Ray Smith
6 // Created: Wed Jul 16 18:35:38 PST 2014
7 //
8 // (C) Copyright 2014, Google Inc.
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 // http://www.apache.org/licenses/LICENSE-2.0
13 // Unless required by applicable law or agreed to in writing, software
14 // distributed under the License is distributed on an "AS IS" BASIS,
15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 // See the License for the specific language governing permissions and
17 // limitations under the License.
19 
20 #include "networkbuilder.h"
21 #include "convolve.h"
22 #include "fullyconnected.h"
23 #include "input.h"
24 #include "lstm.h"
25 #include "maxpool.h"
26 #include "network.h"
27 #include "parallel.h"
28 #include "reconfig.h"
29 #include "reversed.h"
30 #include "series.h"
31 #include "unicharset.h"
32 
33 namespace tesseract {
34 
35 // Builds a network with a network_spec in the network description
36 // language, to recognize a character set of num_outputs size.
37 // If append_index is non-negative, then *network must be non-null and the
38 // given network_spec will be appended to *network AFTER append_index, with
39 // the top of the input *network discarded.
40 // Note that network_spec is call by value to allow a non-const char* pointer
41 // into the string for BuildFromString.
42 // net_flags control network behavior according to the NetworkFlags enum.
43 // The resulting network is returned via **network.
44 // Returns false if something failed.
45 bool NetworkBuilder::InitNetwork(int num_outputs, STRING network_spec,
46  int append_index, int net_flags,
47  float weight_range, TRand* randomizer,
48  Network** network) {
49  NetworkBuilder builder(num_outputs);
50  Series* bottom_series = nullptr;
51  StaticShape input_shape;
52  if (append_index >= 0) {
53  // Split the current network after the given append_index.
54  ASSERT_HOST(*network != nullptr && (*network)->type() == NT_SERIES);
55  auto* series = static_cast<Series*>(*network);
56  Series* top_series = nullptr;
57  series->SplitAt(append_index, &bottom_series, &top_series);
58  if (bottom_series == nullptr || top_series == nullptr) {
59  tprintf("Yikes! Splitting current network failed!!\n");
60  return false;
61  }
62  input_shape = bottom_series->OutputShape(input_shape);
63  delete top_series;
64  }
65  char* str_ptr = &network_spec[0];
66  *network = builder.BuildFromString(input_shape, &str_ptr);
67  if (*network == nullptr) return false;
68  (*network)->SetNetworkFlags(net_flags);
69  (*network)->InitWeights(weight_range, randomizer);
70  (*network)->SetupNeedsBackprop(false);
71  if (bottom_series != nullptr) {
72  bottom_series->AppendSeries(*network);
73  *network = bottom_series;
74  }
75  (*network)->CacheXScaleFactor((*network)->XScaleFactor());
76  return true;
77 }
78 
// Helper: advances *str past a leading run of spaces, tabs and newlines.
// Any other character (including '\r' and the terminating '\0') stops it.
static void SkipWhitespace(char** str) {
  char* cursor = *str;
  for (;;) {
    const char ch = *cursor;
    if (ch != ' ' && ch != '\t' && ch != '\n') break;
    ++cursor;
  }
  *str = cursor;
}
83 
84 // Parses the given string and returns a network according to the network
85 // description language in networkbuilder.h
87  char** str) {
88  SkipWhitespace(str);
89  char code_ch = **str;
90  if (code_ch == '[') {
91  return ParseSeries(input_shape, nullptr, str);
92  }
93  if (input_shape.depth() == 0) {
94  // There must be an input at this point.
95  return ParseInput(str);
96  }
97  switch (code_ch) {
98  case '(':
99  return ParseParallel(input_shape, str);
100  case 'R':
101  return ParseR(input_shape, str);
102  case 'S':
103  return ParseS(input_shape, str);
104  case 'C':
105  return ParseC(input_shape, str);
106  case 'M':
107  return ParseM(input_shape, str);
108  case 'L':
109  return ParseLSTM(input_shape, str);
110  case 'F':
111  return ParseFullyConnected(input_shape, str);
112  case 'O':
113  return ParseOutput(input_shape, str);
114  default:
115  tprintf("Invalid network spec:%s\n", *str);
116  return nullptr;
117  }
118  return nullptr;
119 }
120 
121 // Parses an input specification and returns the result, which may include a
122 // series.
123 Network* NetworkBuilder::ParseInput(char** str) {
124  // There must be an input at this point.
125  int length = 0;
126  int batch, height, width, depth;
127  int num_converted =
128  sscanf(*str, "%d,%d,%d,%d%n", &batch, &height, &width, &depth, &length);
129  StaticShape shape;
130  shape.SetShape(batch, height, width, depth);
131  // num_converted may or may not include the length.
132  if (num_converted != 4 && num_converted != 5) {
133  tprintf("Must specify an input layer as the first layer, not %s!!\n", *str);
134  return nullptr;
135  }
136  *str += length;
137  Input* input = new Input("Input", shape);
138  // We want to allow [<input>rest of net... or <input>[rest of net... so we
139  // have to check explicitly for '[' here.
140  SkipWhitespace(str);
141  if (**str == '[') return ParseSeries(shape, input, str);
142  return input;
143 }
144 
145 // Parses a sequential series of networks, defined by [<net><net>...].
146 Network* NetworkBuilder::ParseSeries(const StaticShape& input_shape,
147  Input* input_layer, char** str) {
148  StaticShape shape = input_shape;
149  Series* series = new Series("Series");
150  ++*str;
151  if (input_layer != nullptr) {
152  series->AddToStack(input_layer);
153  shape = input_layer->OutputShape(shape);
154  }
155  Network* network = nullptr;
156  while (**str != '\0' && **str != ']' &&
157  (network = BuildFromString(shape, str)) != nullptr) {
158  shape = network->OutputShape(shape);
159  series->AddToStack(network);
160  }
161  if (**str != ']') {
162  tprintf("Missing ] at end of [Series]!\n");
163  delete series;
164  return nullptr;
165  }
166  ++*str;
167  return series;
168 }
169 
170 // Parses a parallel set of networks, defined by (<net><net>...).
171 Network* NetworkBuilder::ParseParallel(const StaticShape& input_shape,
172  char** str) {
173  Parallel* parallel = new Parallel("Parallel", NT_PARALLEL);
174  ++*str;
175  Network* network = nullptr;
176  while (**str != '\0' && **str != ')' &&
177  (network = BuildFromString(input_shape, str)) != nullptr) {
178  parallel->AddToStack(network);
179  }
180  if (**str != ')') {
181  tprintf("Missing ) at end of (Parallel)!\n");
182  delete parallel;
183  return nullptr;
184  }
185  ++*str;
186  return parallel;
187 }
188 
189 // Parses a network that begins with 'R'.
190 Network* NetworkBuilder::ParseR(const StaticShape& input_shape, char** str) {
191  char dir = (*str)[1];
192  if (dir == 'x' || dir == 'y') {
193  STRING name = "Reverse";
194  name += dir;
195  *str += 2;
196  Network* network = BuildFromString(input_shape, str);
197  if (network == nullptr) return nullptr;
198  auto* rev =
199  new Reversed(name, dir == 'y' ? NT_YREVERSED : NT_XREVERSED);
200  rev->SetNetwork(network);
201  return rev;
202  }
203  int replicas = strtol(*str + 1, str, 10);
204  if (replicas <= 0) {
205  tprintf("Invalid R spec!:%s\n", *str);
206  return nullptr;
207  }
208  Parallel* parallel = new Parallel("Replicated", NT_REPLICATED);
209  char* str_copy = *str;
210  for (int i = 0; i < replicas; ++i) {
211  str_copy = *str;
212  Network* network = BuildFromString(input_shape, &str_copy);
213  if (network == nullptr) {
214  tprintf("Invalid replicated network!\n");
215  delete parallel;
216  return nullptr;
217  }
218  parallel->AddToStack(network);
219  }
220  *str = str_copy;
221  return parallel;
222 }
223 
224 // Parses a network that begins with 'S'.
225 Network* NetworkBuilder::ParseS(const StaticShape& input_shape, char** str) {
226  int y = strtol(*str + 1, str, 10);
227  if (**str == ',') {
228  int x = strtol(*str + 1, str, 10);
229  if (y <= 0 || x <= 0) {
230  tprintf("Invalid S spec!:%s\n", *str);
231  return nullptr;
232  }
233  return new Reconfig("Reconfig", input_shape.depth(), x, y);
234  } else if (**str == '(') {
235  // TODO(rays) Add Generic reshape.
236  tprintf("Generic reshape not yet implemented!!\n");
237  return nullptr;
238  }
239  tprintf("Invalid S spec!:%s\n", *str);
240  return nullptr;
241 }
242 
243 // Helper returns the fully-connected type for the character code.
244 static NetworkType NonLinearity(char func) {
245  switch (func) {
246  case 's':
247  return NT_LOGISTIC;
248  case 't':
249  return NT_TANH;
250  case 'r':
251  return NT_RELU;
252  case 'l':
253  return NT_LINEAR;
254  case 'm':
255  return NT_SOFTMAX;
256  case 'p':
257  return NT_POSCLIP;
258  case 'n':
259  return NT_SYMCLIP;
260  default:
261  return NT_NONE;
262  }
263 }
264 
265 // Parses a network that begins with 'C'.
266 Network* NetworkBuilder::ParseC(const StaticShape& input_shape, char** str) {
267  NetworkType type = NonLinearity((*str)[1]);
268  if (type == NT_NONE) {
269  tprintf("Invalid nonlinearity on C-spec!: %s\n", *str);
270  return nullptr;
271  }
272  int y = 0, x = 0, d = 0;
273  if ((y = strtol(*str + 2, str, 10)) <= 0 || **str != ',' ||
274  (x = strtol(*str + 1, str, 10)) <= 0 || **str != ',' ||
275  (d = strtol(*str + 1, str, 10)) <= 0) {
276  tprintf("Invalid C spec!:%s\n", *str);
277  return nullptr;
278  }
279  if (x == 1 && y == 1) {
280  // No actual convolution. Just a FullyConnected on the current depth, to
281  // be slid over all batch,y,x.
282  return new FullyConnected("Conv1x1", input_shape.depth(), d, type);
283  }
284  Series* series = new Series("ConvSeries");
285  Convolve* convolve =
286  new Convolve("Convolve", input_shape.depth(), x / 2, y / 2);
287  series->AddToStack(convolve);
288  StaticShape fc_input = convolve->OutputShape(input_shape);
289  series->AddToStack(new FullyConnected("ConvNL", fc_input.depth(), d, type));
290  return series;
291 }
292 
293 // Parses a network that begins with 'M'.
294 Network* NetworkBuilder::ParseM(const StaticShape& input_shape, char** str) {
295  int y = 0, x = 0;
296  if ((*str)[1] != 'p' || (y = strtol(*str + 2, str, 10)) <= 0 ||
297  **str != ',' || (x = strtol(*str + 1, str, 10)) <= 0) {
298  tprintf("Invalid Mp spec!:%s\n", *str);
299  return nullptr;
300  }
301  return new Maxpool("Maxpool", input_shape.depth(), x, y);
302 }
303 
304 // Parses an LSTM network, either individual, bi- or quad-directional.
305 Network* NetworkBuilder::ParseLSTM(const StaticShape& input_shape, char** str) {
306  bool two_d = false;
308  char* spec_start = *str;
309  int chars_consumed = 1;
310  int num_outputs = 0;
311  char key = (*str)[chars_consumed], dir = 'f', dim = 'x';
312  if (key == 'S') {
314  num_outputs = num_softmax_outputs_;
315  ++chars_consumed;
316  } else if (key == 'E') {
318  num_outputs = num_softmax_outputs_;
319  ++chars_consumed;
320  } else if (key == '2' && (((*str)[2] == 'x' && (*str)[3] == 'y') ||
321  ((*str)[2] == 'y' && (*str)[3] == 'x'))) {
322  chars_consumed = 4;
323  dim = (*str)[3];
324  two_d = true;
325  } else if (key == 'f' || key == 'r' || key == 'b') {
326  dir = key;
327  dim = (*str)[2];
328  if (dim != 'x' && dim != 'y') {
329  tprintf("Invalid dimension (x|y) in L Spec!:%s\n", *str);
330  return nullptr;
331  }
332  chars_consumed = 3;
333  if ((*str)[chars_consumed] == 's') {
334  ++chars_consumed;
336  }
337  } else {
338  tprintf("Invalid direction (f|r|b) in L Spec!:%s\n", *str);
339  return nullptr;
340  }
341  int num_states = strtol(*str + chars_consumed, str, 10);
342  if (num_states <= 0) {
343  tprintf("Invalid number of states in L Spec!:%s\n", *str);
344  return nullptr;
345  }
346  Network* lstm = nullptr;
347  if (two_d) {
348  lstm = BuildLSTMXYQuad(input_shape.depth(), num_states);
349  } else {
350  if (num_outputs == 0) num_outputs = num_states;
351  STRING name(spec_start, *str - spec_start);
352  lstm = new LSTM(name, input_shape.depth(), num_states, num_outputs, false,
353  type);
354  if (dir != 'f') {
355  Reversed* rev = new Reversed("RevLSTM", NT_XREVERSED);
356  rev->SetNetwork(lstm);
357  lstm = rev;
358  }
359  if (dir == 'b') {
360  name += "LTR";
361  Parallel* parallel = new Parallel("BidiLSTM", NT_PAR_RL_LSTM);
362  parallel->AddToStack(new LSTM(name, input_shape.depth(), num_states,
363  num_outputs, false, type));
364  parallel->AddToStack(lstm);
365  lstm = parallel;
366  }
367  }
368  if (dim == 'y') {
369  Reversed* rev = new Reversed("XYTransLSTM", NT_XYTRANSPOSE);
370  rev->SetNetwork(lstm);
371  lstm = rev;
372  }
373  return lstm;
374 }
375 
376 // Builds a set of 4 lstms with x and y reversal, running in true parallel.
377 Network* NetworkBuilder::BuildLSTMXYQuad(int num_inputs, int num_states) {
378  Parallel* parallel = new Parallel("2DLSTMQuad", NT_PAR_2D_LSTM);
379  parallel->AddToStack(new LSTM("L2DLTRDown", num_inputs, num_states,
380  num_states, true, NT_LSTM));
381  Reversed* rev = new Reversed("L2DLTRXRev", NT_XREVERSED);
382  rev->SetNetwork(new LSTM("L2DRTLDown", num_inputs, num_states, num_states,
383  true, NT_LSTM));
384  parallel->AddToStack(rev);
385  rev = new Reversed("L2DRTLYRev", NT_YREVERSED);
386  rev->SetNetwork(
387  new LSTM("L2DRTLUp", num_inputs, num_states, num_states, true, NT_LSTM));
388  Reversed* rev2 = new Reversed("L2DXRevU", NT_XREVERSED);
389  rev2->SetNetwork(rev);
390  parallel->AddToStack(rev2);
391  rev = new Reversed("L2DXRevY", NT_YREVERSED);
392  rev->SetNetwork(new LSTM("L2DLTRDown", num_inputs, num_states, num_states,
393  true, NT_LSTM));
394  parallel->AddToStack(rev);
395  return parallel;
396 }
397 
398 // Helper builds a truly (0-d) fully connected layer of the given type.
399 static Network* BuildFullyConnected(const StaticShape& input_shape,
400  NetworkType type, const STRING& name,
401  int depth) {
402  if (input_shape.height() == 0 || input_shape.width() == 0) {
403  tprintf("Fully connected requires positive height and width, had %d,%d\n",
404  input_shape.height(), input_shape.width());
405  return nullptr;
406  }
407  int input_size = input_shape.height() * input_shape.width();
408  int input_depth = input_size * input_shape.depth();
409  Network* fc = new FullyConnected(name, input_depth, depth, type);
410  if (input_size > 1) {
411  Series* series = new Series("FCSeries");
412  series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(),
413  input_shape.width(), input_shape.height()));
414  series->AddToStack(fc);
415  fc = series;
416  }
417  return fc;
418 }
419 
420 // Parses a Fully connected network.
421 Network* NetworkBuilder::ParseFullyConnected(const StaticShape& input_shape,
422  char** str) {
423  char* spec_start = *str;
424  NetworkType type = NonLinearity((*str)[1]);
425  if (type == NT_NONE) {
426  tprintf("Invalid nonlinearity on F-spec!: %s\n", *str);
427  return nullptr;
428  }
429  int depth = strtol(*str + 2, str, 10);
430  if (depth <= 0) {
431  tprintf("Invalid F spec!:%s\n", *str);
432  return nullptr;
433  }
434  STRING name(spec_start, *str - spec_start);
435  return BuildFullyConnected(input_shape, type, name, depth);
436 }
437 
438 // Parses an Output spec.
439 Network* NetworkBuilder::ParseOutput(const StaticShape& input_shape,
440  char** str) {
441  char dims_ch = (*str)[1];
442  if (dims_ch != '0' && dims_ch != '1' && dims_ch != '2') {
443  tprintf("Invalid dims (2|1|0) in output spec!:%s\n", *str);
444  return nullptr;
445  }
446  char type_ch = (*str)[2];
447  if (type_ch != 'l' && type_ch != 's' && type_ch != 'c') {
448  tprintf("Invalid output type (l|s|c) in output spec!:%s\n", *str);
449  return nullptr;
450  }
451  int depth = strtol(*str + 3, str, 10);
452  if (depth != num_softmax_outputs_) {
453  tprintf("Warning: given outputs %d not equal to unicharset of %d.\n", depth,
454  num_softmax_outputs_);
455  depth = num_softmax_outputs_;
456  }
458  if (type_ch == 'l')
459  type = NT_LOGISTIC;
460  else if (type_ch == 's')
462  if (dims_ch == '0') {
463  // Same as standard fully connected.
464  return BuildFullyConnected(input_shape, type, "Output", depth);
465  } else if (dims_ch == '2') {
466  // We don't care if x and/or y are variable.
467  return new FullyConnected("Output2d", input_shape.depth(), depth, type);
468  }
469  // For 1-d y has to be fixed, and if not 1, moved to depth.
470  if (input_shape.height() == 0) {
471  tprintf("Fully connected requires fixed height!\n");
472  return nullptr;
473  }
474  int input_size = input_shape.height();
475  int input_depth = input_size * input_shape.depth();
476  Network* fc = new FullyConnected("Output", input_depth, depth, type);
477  if (input_size > 1) {
478  Series* series = new Series("FCSeries");
479  series->AddToStack(new Reconfig("FCReconfig", input_shape.depth(), 1,
480  input_shape.height()));
481  series->AddToStack(fc);
482  fc = series;
483  }
484  return fc;
485 }
486 
487 } // namespace tesseract.
tesseract::Series::SplitAt
void SplitAt(int last_start, Series **start, Series **end)
Definition: series.cpp:159
tesseract::StaticShape
Definition: static_shape.h:38
tesseract::NT_PARALLEL
Definition: network.h:49
tesseract::NT_POSCLIP
Definition: network.h:63
tesseract::NT_PAR_2D_LSTM
Definition: network.h:53
tesseract::StaticShape::SetShape
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:52
tesseract::NT_XYTRANSPOSE
Definition: network.h:58
tesseract::NT_SOFTMAX_NO_CTC
Definition: network.h:69
tesseract::NT_PAR_RL_LSTM
Definition: network.h:51
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::Series
Definition: series.h:27
STRING
Definition: strngs.h:45
parallel.h
network.h
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::NT_REPLICATED
Definition: network.h:50
tesseract::NetworkType
NetworkType
Definition: network.h:43
tesseract::NT_LSTM
Definition: network.h:60
tesseract::NT_SYMCLIP
Definition: network.h:64
tesseract::NetworkBuilder::BuildFromString
Network * BuildFromString(const StaticShape &input_shape, char **str)
Definition: networkbuilder.cpp:86
maxpool.h
tesseract::Series::OutputShape
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: series.cpp:34
tesseract::NT_SERIES
Definition: network.h:54
unicharset.h
tesseract::StaticShape::depth
int depth() const
Definition: static_shape.h:48
tesseract::NT_YREVERSED
Definition: network.h:57
networkbuilder.h
tesseract::NT_TANH
Definition: network.h:65
tesseract
Definition: baseapi.h:65
tesseract::Network::SetNetworkFlags
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
tesseract::Series::CacheXScaleFactor
void CacheXScaleFactor(int factor) override
Definition: series.cpp:100
lstm.h
tesseract::NT_XREVERSED
Definition: network.h:56
tesseract::Series::AppendSeries
void AppendSeries(Network *src)
Definition: series.cpp:189
tesseract::NT_LSTM_SOFTMAX_ENCODED
Definition: network.h:76
tesseract::Network
Definition: network.h:105
series.h
reconfig.h
fullyconnected.h
tesseract::NT_RELU
Definition: network.h:66
tesseract::NT_NONE
Definition: network.h:44
tesseract::NT_LSTM_SOFTMAX
Definition: network.h:75
tesseract::NetworkBuilder::InitNetwork
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
Definition: networkbuilder.cpp:45
tesseract::NT_LSTM_SUMMARY
Definition: network.h:61
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesstrain_utils.type
type
Definition: tesstrain_utils.py:141
convolve.h
tesseract::NT_LINEAR
Definition: network.h:67
tesseract::NT_LOGISTIC
Definition: network.h:62
reversed.h
tesseract::TRand
Definition: helpers.h:50
tesseract::NetworkBuilder
Definition: networkbuilder.h:36
input.h
tesseract::NT_SOFTMAX
Definition: network.h:68