tesseract  5.0.0-alpha-619-ge9db
parallel.h
Go to the documentation of this file.
1 // File: parallel.h
3 // Description: Runs networks in parallel on the same input.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:02:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_PARALLEL_H_
20 #define TESSERACT_LSTM_PARALLEL_H_
21 
22 #include "plumbing.h"
23 
24 namespace tesseract {
25 
26 // Runs multiple networks in parallel, interlacing their outputs.
27 class Parallel : public Plumbing {
28  public:
29  // ni_ and no_ will be set by AddToStack.
31  ~Parallel() override = default;
32 
33  // Returns the shape output from the network given an input shape (which may
34  // be partially unknown ie zero).
35  StaticShape OutputShape(const StaticShape& input_shape) const override;
36 
37  STRING spec() const override {
38  STRING spec;
39  if (type_ == NT_PAR_2D_LSTM) {
40  // We have 4 LSTMs operating in parallel here, so the size of each is
41  // the number of outputs/4.
42  spec.add_str_int("L2xy", no_ / 4);
43  } else if (type_ == NT_PAR_RL_LSTM) {
44  // We have 2 LSTMs operating in parallel here, so the size of each is
45  // the number of outputs/2.
46  if (stack_[0]->type() == NT_LSTM_SUMMARY)
47  spec.add_str_int("Lbxs", no_ / 2);
48  else
49  spec.add_str_int("Lbx", no_ / 2);
50  } else {
51  if (type_ == NT_REPLICATED) {
52  spec.add_str_int("R", stack_.size());
53  spec += "(";
54  spec += stack_[0]->spec();
55  } else {
56  spec = "(";
57  for (int i = 0; i < stack_.size(); ++i) spec += stack_[i]->spec();
58  }
59  spec += ")";
60  }
61  return spec;
62  }
63 
64  // Runs forward propagation of activations on the input line.
65  // See Network for a detailed discussion of the arguments.
66  void Forward(bool debug, const NetworkIO& input,
67  const TransposedArray* input_transpose,
68  NetworkScratch* scratch, NetworkIO* output) override;
69 
70  // Runs backward propagation of errors on the deltas line.
71  // See Network for a detailed discussion of the arguments.
72  bool Backward(bool debug, const NetworkIO& fwd_deltas,
73  NetworkScratch* scratch,
74  NetworkIO* back_deltas) override;
75 
76  private:
77  // If *this is a NT_REPLICATED, then it feeds a replicated network with
78  // identical inputs, and it would be extremely wasteful for them to each
79  // calculate and store the same transpose of the inputs, so Parallel does it
80  // and passes a pointer to the replicated network, allowing it to use the
81  // transpose on the next call to Backward.
82  TransposedArray transposed_input_;
83 };
84 
85 } // namespace tesseract.
86 
87 #endif // TESSERACT_LSTM_PARALLEL_H_
tesseract::StaticShape
Definition: static_shape.h:38
tesseract::Parallel::Forward
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: parallel.cpp:49
tesseract::NT_PAR_2D_LSTM
Definition: network.h:53
tesseract::Parallel::Backward
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: parallel.cpp:110
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition: strngs.cpp:370
tesseract::NT_PAR_RL_LSTM
Definition: network.h:51
tesseract::Parallel::~Parallel
~Parallel() override=default
tesseract::Parallel
Definition: parallel.h:27
tesseract::Parallel::OutputShape
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition: parallel.cpp:37
STRING
Definition: strngs.h:45
tesseract::NetworkScratch
Definition: networkscratch.h:34
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::NT_REPLICATED
Definition: network.h:50
tesseract::NetworkType
NetworkType
Definition: network.h:43
tesseract::Plumbing::stack_
PointerVector< Network > stack_
Definition: plumbing.h:136
tesseract::Network::type_
NetworkType type_
Definition: network.h:293
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::Plumbing
Definition: plumbing.h:30
tesseract
Definition: baseapi.h:65
tesseract::Parallel::spec
STRING spec() const override
Definition: parallel.h:37
tesseract::Network::name
const STRING & name() const
Definition: network.h:138
tesseract::TransposedArray
Definition: weightmatrix.h:32
tesseract::NT_LSTM_SUMMARY
Definition: network.h:61
tesseract::Parallel::Parallel
Parallel(const STRING &name, NetworkType type)
Definition: parallel.cpp:31
tesseract::Network::no_
int32_t no_
Definition: network.h:298
plumbing.h