tesseract
5.0.0-alpha-619-ge9db
parallel.h
Go to the documentation of this file.
1
// File: parallel.h
3
// Description: Runs networks in parallel on the same input.
4
// Author: Ray Smith
5
// Created: Thu May 02 08:02:06 PST 2013
6
//
7
// (C) Copyright 2013, Google Inc.
8
// Licensed under the Apache License, Version 2.0 (the "License");
9
// you may not use this file except in compliance with the License.
10
// You may obtain a copy of the License at
11
// http://www.apache.org/licenses/LICENSE-2.0
12
// Unless required by applicable law or agreed to in writing, software
13
// distributed under the License is distributed on an "AS IS" BASIS,
14
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
// See the License for the specific language governing permissions and
16
// limitations under the License.
18
19
#ifndef TESSERACT_LSTM_PARALLEL_H_
20
#define TESSERACT_LSTM_PARALLEL_H_
21
22
#include "
plumbing.h
"
23
24
namespace
tesseract
{
25
26
// Runs multiple networks in parallel, interlacing their outputs.
27
class
Parallel
:
public
Plumbing
{
28
public
:
29
// ni_ and no_ will be set by AddToStack.
30
Parallel
(
const
STRING
&
name
,
NetworkType
type
);
31
~Parallel
()
override
=
default
;
32
33
// Returns the shape output from the network given an input shape (which may
34
// be partially unknown ie zero).
35
StaticShape
OutputShape
(
const
StaticShape
& input_shape)
const override
;
36
37
STRING
spec
()
const override
{
38
STRING
spec
;
39
if
(
type_
==
NT_PAR_2D_LSTM
) {
40
// We have 4 LSTMs operating in parallel here, so the size of each is
41
// the number of outputs/4.
42
spec
.
add_str_int
(
"L2xy"
,
no_
/ 4);
43
}
else
if
(
type_
==
NT_PAR_RL_LSTM
) {
44
// We have 2 LSTMs operating in parallel here, so the size of each is
45
// the number of outputs/2.
46
if
(
stack_
[0]->
type
() ==
NT_LSTM_SUMMARY
)
47
spec
.
add_str_int
(
"Lbxs"
,
no_
/ 2);
48
else
49
spec
.
add_str_int
(
"Lbx"
,
no_
/ 2);
50
}
else
{
51
if
(
type_
==
NT_REPLICATED
) {
52
spec
.
add_str_int
(
"R"
,
stack_
.size());
53
spec
+=
"("
;
54
spec
+=
stack_
[0]->spec();
55
}
else
{
56
spec
=
"("
;
57
for
(
int
i = 0; i <
stack_
.size(); ++i)
spec
+=
stack_
[i]->
spec
();
58
}
59
spec
+=
")"
;
60
}
61
return
spec
;
62
}
63
64
// Runs forward propagation of activations on the input line.
65
// See Network for a detailed discussion of the arguments.
66
void
Forward
(
bool
debug,
const
NetworkIO
& input,
67
const
TransposedArray
* input_transpose,
68
NetworkScratch
* scratch,
NetworkIO
* output)
override
;
69
70
// Runs backward propagation of errors on the deltas line.
71
// See Network for a detailed discussion of the arguments.
72
bool
Backward
(
bool
debug,
const
NetworkIO
& fwd_deltas,
73
NetworkScratch
* scratch,
74
NetworkIO
* back_deltas)
override
;
75
76
private
:
77
// If *this is a NT_REPLICATED, then it feeds a replicated network with
78
// identical inputs, and it would be extremely wasteful for them to each
79
// calculate and store the same transpose of the inputs, so Parallel does it
80
// and passes a pointer to the replicated network, allowing it to use the
81
// transpose on the next call to Backward.
82
TransposedArray
transposed_input_;
83
};
84
85
}
// namespace tesseract.
86
87
#endif // TESSERACT_LSTM_PARALLEL_H_
tesseract::StaticShape
Definition:
static_shape.h:38
tesseract::Parallel::Forward
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition:
parallel.cpp:49
tesseract::NT_PAR_2D_LSTM
Definition:
network.h:53
tesseract::Parallel::Backward
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition:
parallel.cpp:110
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition:
strngs.cpp:370
tesseract::NT_PAR_RL_LSTM
Definition:
network.h:51
tesseract::Parallel::~Parallel
~Parallel() override=default
tesseract::Parallel
Definition:
parallel.h:27
tesseract::Parallel::OutputShape
StaticShape OutputShape(const StaticShape &input_shape) const override
Definition:
parallel.cpp:37
STRING
Definition:
strngs.h:45
tesseract::NetworkScratch
Definition:
networkscratch.h:34
tesseract::Network::type
NetworkType type() const
Definition:
network.h:112
tesseract::NT_REPLICATED
Definition:
network.h:50
tesseract::NetworkType
NetworkType
Definition:
network.h:43
tesseract::Plumbing::stack_
PointerVector< Network > stack_
Definition:
plumbing.h:136
tesseract::Network::type_
NetworkType type_
Definition:
network.h:293
tesseract::NetworkIO
Definition:
networkio.h:39
tesseract::Plumbing
Definition:
plumbing.h:30
tesseract
Definition:
baseapi.h:65
tesseract::Parallel::spec
STRING spec() const override
Definition:
parallel.h:37
tesseract::Network::name
const STRING & name() const
Definition:
network.h:138
tesseract::TransposedArray
Definition:
weightmatrix.h:32
tesseract::NT_LSTM_SUMMARY
Definition:
network.h:61
tesseract::Parallel::Parallel
Parallel(const STRING &name, NetworkType type)
Definition:
parallel.cpp:31
tesseract::Network::no_
int32_t no_
Definition:
network.h:298
plumbing.h
src
lstm
parallel.h
Generated on Thu Jan 30 2020 14:22:20 for tesseract by doxygen 1.8.16