tesseract  4.0.0-1-g2a2b
parallel.h
Go to the documentation of this file.
1 // File: parallel.h
3 // Description: Runs networks in parallel on the same input.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:02:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_PARALLEL_H_
20 #define TESSERACT_LSTM_PARALLEL_H_
21 
22 #include "plumbing.h"
23 
24 namespace tesseract {
25 
26 // Runs multiple networks in parallel, interlacing their outputs.
27 class Parallel : public Plumbing {
28  public:
29  // ni_ and no_ will be set by AddToStack.
31  virtual ~Parallel() = default;
32 
33  // Returns the shape output from the network given an input shape (which may
34  // be partially unknown ie zero).
35  StaticShape OutputShape(const StaticShape& input_shape) const override;
36 
37  STRING spec() const override {
38  STRING spec;
39  if (type_ == NT_PAR_2D_LSTM) {
40  // We have 4 LSTMs operating in parallel here, so the size of each is
41  // the number of outputs/4.
42  spec.add_str_int("L2xy", no_ / 4);
43  } else if (type_ == NT_PAR_RL_LSTM) {
44  // We have 2 LSTMs operating in parallel here, so the size of each is
45  // the number of outputs/2.
46  if (stack_[0]->type() == NT_LSTM_SUMMARY)
47  spec.add_str_int("Lbxs", no_ / 2);
48  else
49  spec.add_str_int("Lbx", no_ / 2);
50  } else {
51  if (type_ == NT_REPLICATED) {
52  spec.add_str_int("R", stack_.size());
53  spec += "(";
54  spec += stack_[0]->spec();
55  } else {
56  spec = "(";
57  for (int i = 0; i < stack_.size(); ++i) spec += stack_[i]->spec();
58  }
59  spec += ")";
60  }
61  return spec;
62  }
63 
64  // Runs forward propagation of activations on the input line.
65  // See Network for a detailed discussion of the arguments.
66  void Forward(bool debug, const NetworkIO& input,
67  const TransposedArray* input_transpose,
68  NetworkScratch* scratch, NetworkIO* output) override;
69 
70  // Runs backward propagation of errors on the deltas line.
71  // See Network for a detailed discussion of the arguments.
72  bool Backward(bool debug, const NetworkIO& fwd_deltas,
73  NetworkScratch* scratch,
74  NetworkIO* back_deltas) override;
75 
76  private:
77  // If *this is a NT_REPLICATED, then it feeds a replicated network with
78  // identical inputs, and it would be extremely wasteful for them to each
79  // calculate and store the same transpose of the inputs, so Parallel does it
80  // and passes a pointer to the replicated network, allowing it to use the
81  // transpose on the next call to Backward.
82  TransposedArray transposed_input_;
83 };
84 
85 } // namespace tesseract.
86 
87 #endif // TESSERACT_LSTM_PARALLEL_H_
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: parallel.cpp:110
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: parallel.cpp:49
STRING spec() const override
Definition: parallel.h:37
virtual ~Parallel()=default
NetworkType
Definition: network.h:43
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: parallel.cpp:37
Parallel(const STRING &name, NetworkType type)
Definition: parallel.cpp:31
NetworkType type_
Definition: network.h:299
PointerVector< Network > stack_
Definition: plumbing.h:136
const STRING & name() const
Definition: network.h:138
NetworkType type() const
Definition: network.h:112
void add_str_int(const char *str, int number)
Definition: strngs.cpp:379
Definition: strngs.h:45