tesseract  5.0.0-alpha-619-ge9db
plumbing.cpp
Go to the documentation of this file.
1 // File: plumbing.cpp
3 // Description: Base class for networks that organize other networks
4 // e.g. series or parallel.
5 // Author: Ray Smith
6 // Created: Mon May 12 08:17:34 PST 2014
7 //
8 // (C) Copyright 2014, Google Inc.
9 // Licensed under the Apache License, Version 2.0 (the "License");
10 // you may not use this file except in compliance with the License.
11 // You may obtain a copy of the License at
12 // http://www.apache.org/licenses/LICENSE-2.0
13 // Unless required by applicable law or agreed to in writing, software
14 // distributed under the License is distributed on an "AS IS" BASIS,
15 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16 // See the License for the specific language governing permissions and
17 // limitations under the License.
19 
20 #include "plumbing.h"
21 
22 namespace tesseract {
23 
24 // ni_ and no_ will be set by AddToStack.
// Constructs an empty plumbing network whose base Network is created with
// type NT_PARALLEL and the given name; the two trailing 0 arguments are
// presumably the initial ni/no (consistent with the note above) — confirm
// against network.h.
// NOTE(review): listing line 25 (the signature) is missing from this extract;
// the cross-reference index gives it as `Plumbing(const STRING &name)`.
// NOTE(review): AddToStack also handles type_ == NT_SERIES, so presumably
// series-type subclasses overwrite type_ after this constructor — confirm.
26  : Network(NT_PARALLEL, name, 0, 0) {
27 }
28 
29 // Suspends/Enables training by setting the training_ flag. Serialize and
30 // DeSerialize only operate on the run-time data if state is false.
// NOTE(review): `state` is a TrainingState (per the index entry), so "false"
// above presumably means the non-training state value — confirm in network.h.
// NOTE(review): listing lines 31-32 (the signature and whatever preceded the
// loop) are missing from this extract; the index gives the signature as
// `void SetEnableTraining(TrainingState state) override`, and also lists a
// base-class Network::SetEnableTraining that is presumably invoked there.
// Propagates the same training state to every network in the stack.
33  for (int i = 0; i < stack_.size(); ++i)
34  stack_[i]->SetEnableTraining(state);
35 }
36 
37 // Sets flags that control the action of the network. See NetworkFlags enum
38 // for bit values.
39 void Plumbing::SetNetworkFlags(uint32_t flags) {
// NOTE(review): listing line 40 is missing from this extract; the index lists
// a base-class Network::SetNetworkFlags (and the network_flags_ member read by
// Serialize/Update), which this line presumably calls — confirm.
// Forwards the same flags to every network in the stack.
41  for (int i = 0; i < stack_.size(); ++i)
42  stack_[i]->SetNetworkFlags(flags);
43 }
44 
45 // Sets up the network for training. Initializes weights using weights of
46 // scale `range` picked according to the random number generator `randomizer`.
47 // Note that randomizer is a borrowed pointer that should outlive the network
48 // and should not be deleted by any of the networks.
49 // Returns the number of weights initialized.
50 int Plumbing::InitWeights(float range, TRand* randomizer) {
51  num_weights_ = 0;
52  for (int i = 0; i < stack_.size(); ++i)
53  num_weights_ += stack_[i]->InitWeights(range, randomizer);
54  return num_weights_;
55 }
56 
57 // Recursively searches the network for softmaxes with old_no outputs,
58 // and remaps their outputs according to code_map. See network.h for details.
59 int Plumbing::RemapOutputs(int old_no, const std::vector<int>& code_map) {
60  num_weights_ = 0;
61  for (int i = 0; i < stack_.size(); ++i) {
62  num_weights_ += stack_[i]->RemapOutputs(old_no, code_map);
63  }
64  return num_weights_;
65 }
66 
67 // Converts a float network to an int network.
// NOTE(review): listing line 68 (the signature) is missing from this extract;
// the index gives it as `void ConvertToInt() override`.
// Asks every network in the stack to perform its own conversion.
69  for (int i = 0; i < stack_.size(); ++i)
70  stack_[i]->ConvertToInt();
71 }
72 
73 // Provides a pointer to a TRand for any networks that care to use it.
74 // Note that randomizer is a borrowed pointer that should outlive the network
75 // and should not be deleted by any of the networks.
76 void Plumbing::SetRandomizer(TRand* randomizer) {
77  for (int i = 0; i < stack_.size(); ++i)
78  stack_[i]->SetRandomizer(randomizer);
79 }
80 
81 // Adds the given network to the stack.
82 void Plumbing::AddToStack(Network* network) {
83  if (stack_.empty()) {
84  ni_ = network->NumInputs();
85  no_ = network->NumOutputs();
86  } else if (type_ == NT_SERIES) {
87  // ni is input of first, no output of last, others match output to input.
88  ASSERT_HOST(no_ == network->NumInputs());
89  no_ = network->NumOutputs();
90  } else {
91  // All parallel types. Output is sum of outputs, inputs all match.
92  ASSERT_HOST(ni_ == network->NumInputs());
93  no_ += network->NumOutputs();
94  }
95  stack_.push_back(network);
96 }
97 
98 // Sets needs_to_backprop_ to needs_backprop and calls on sub-network
99 // according to needs_backprop || any weights in this network.
100 bool Plumbing::SetupNeedsBackprop(bool needs_backprop) {
101  if (IsTraining()) {
102  needs_to_backprop_ = needs_backprop;
103  bool retval = needs_backprop;
104  for (int i = 0; i < stack_.size(); ++i) {
105  if (stack_[i]->SetupNeedsBackprop(needs_backprop)) retval = true;
106  }
107  return retval;
108  }
109  // Frozen networks don't do backprop.
110  needs_to_backprop_ = false;
111  return false;
112 }
113 
114 // Returns an integer reduction factor that the network applies to the
115 // time sequence. Assumes that any 2-d is already eliminated. Used for
116 // scaling bounding boxes of truth data.
117 // WARNING: if GlobalMinimax is used to vary the scale, this will return
118 // the last used scale factor. Call it before any forward, and it will return
119 // the minimum scale factor of the paths through the GlobalMinimax.
// NOTE(review): listing line 120 (the signature) is missing from this extract;
// the index gives it as `int XScaleFactor() const override`.
// Only the first network in the stack is consulted.
// NOTE(review): nothing here checks that stack_ is non-empty; calling this
// before any AddToStack would index an empty vector.
121  return stack_[0]->XScaleFactor();
122 }
123 
124 // Provides the (minimum) x scale factor to the network (of interest only to
125 // input units) so they can determine how to scale bounding boxes.
126 void Plumbing::CacheXScaleFactor(int factor) {
127  for (int i = 0; i < stack_.size(); ++i) {
128  stack_[i]->CacheXScaleFactor(factor);
129  }
130 }
131 
132 // Provides debug output on the weights.
// NOTE(review): listing line 133 (the signature) is missing from this extract;
// the index gives it as `void DebugWeights() override`.
// Delegates to every network in the stack, in stack order.
134  for (int i = 0; i < stack_.size(); ++i)
135  stack_[i]->DebugWeights();
136 }
137 
138 // Returns a set of strings representing the layer-ids of all layers below.
// NOTE(review): listing line 139 (the first half of the signature) is missing
// from this extract; the index gives the full signature as
// `void EnumerateLayers(const STRING *prefix, GenericVector< STRING > *layers) const`.
// Results are appended to *layers. Each id is the (optional) prefix followed
// by ":" and the stack index, presumably as decimal via add_str_int — confirm.
// Nested plumbing networks are expanded recursively using the extended id as
// their prefix, so only non-plumbing layers are appended.
140  GenericVector<STRING>* layers) const {
141  for (int i = 0; i < stack_.size(); ++i) {
142  STRING layer_name;
143  if (prefix) layer_name = *prefix;
144  layer_name.add_str_int(":", i);
145  if (stack_[i]->IsPlumbingType()) {
146  auto* plumbing = static_cast<Plumbing*>(stack_[i]);
147  plumbing->EnumerateLayers(&layer_name, layers);
148  } else {
149  layers->push_back(layer_name);
150  }
151  }
152 }
153 
154 // Returns a pointer to the network layer corresponding to the given id.
155 Network* Plumbing::GetLayer(const char* id) const {
156  char* next_id;
157  int index = strtol(id, &next_id, 10);
158  if (index < 0 || index >= stack_.size()) return nullptr;
159  if (stack_[index]->IsPlumbingType()) {
160  auto* plumbing = static_cast<Plumbing*>(stack_[index]);
161  ASSERT_HOST(*next_id == ':');
162  return plumbing->GetLayer(next_id + 1);
163  }
164  return stack_[index];
165 }
166 
167 // Returns a pointer to the learning rate for the given layer id.
168 float* Plumbing::LayerLearningRatePtr(const char* id) const {
169  char* next_id;
170  int index = strtol(id, &next_id, 10);
171  if (index < 0 || index >= stack_.size()) return nullptr;
172  if (stack_[index]->IsPlumbingType()) {
173  auto* plumbing = static_cast<Plumbing*>(stack_[index]);
174  ASSERT_HOST(*next_id == ':');
175  return plumbing->LayerLearningRatePtr(next_id + 1);
176  }
177  if (index >= learning_rates_.size()) return nullptr;
178  return &learning_rates_[index];
179 }
180 
181 // Writes to the given file. Returns false in case of error.
// Layout written: the base Network data, a uint32_t count of sub-networks,
// each sub-network's own serialization in stack order, then learning_rates_
// (see the note below about the condition guarding it).
// NOTE(review): listing line 189 is missing from this extract; it presumably
// opens the condition that makes writing learning_rates_ optional (the index
// lists network_flags_ and NF_LAYER_SPECIFIC_LR) — confirm it stays in sync
// with the matching read in DeSerialize.
182 bool Plumbing::Serialize(TFile* fp) const {
183  if (!Network::Serialize(fp)) return false;
184  uint32_t size = stack_.size();
185  // Can't use PointerVector::Serialize here as we need a special DeSerialize.
186  if (!fp->Serialize(&size)) return false;
187  for (uint32_t i = 0; i < size; ++i)
188  if (!stack_[i]->Serialize(fp)) return false;
190  !learning_rates_.Serialize(fp)) {
191  return false;
192  }
193  return true;
194 }
195 
196 // Reads from the given file. Returns false in case of error.
// NOTE(review): listing line 197 (the signature) and lines 207-208 are missing
// from this extract; the index gives the signature as
// `bool DeSerialize(TFile *fp) override`, and the missing lines presumably
// read learning_rates_ under the same condition Serialize uses for writing.
// Clears the current stack, reads the sub-network count, then rebuilds the
// stack by creating each sub-network from the file and re-adding it through
// AddToStack, which recomputes ni_/no_.
// NOTE(review): on a failed read the function returns false with stack_ only
// partially rebuilt. The count comes straight from the file and is not
// bounded here.
198  stack_.truncate(0);
199  no_ = 0; // We will be modifying this as we AddToStack.
200  uint32_t size;
201  if (!fp->DeSerialize(&size)) return false;
202  for (uint32_t i = 0; i < size; ++i) {
203  Network* network = CreateFromFile(fp);
204  if (network == nullptr) return false;
205  AddToStack(network);
206  }
209  return false;
210  }
211  return true;
212 }
213 
214 // Updates the weights using the given learning rate, momentum and adam_beta.
215 // num_samples is used in the adam computation iff use_adam_ is true.
// Only sub-networks that report IsTraining() are updated.
// NOTE(review): listing line 219 is missing from this extract; it presumably
// opens the block that handles per-layer learning rates (the index lists
// network_flags_ and NF_LAYER_SPECIFIC_LR) and is closed by line 224.
// Inside that block, a layer with a stored rate uses it. Otherwise the current
// learning_rate is appended to learning_rates_ for that layer.
// NOTE(review): learning_rate is overwritten inside the loop, so a layer with
// no stored rate inherits the previous layer's stored rate rather than the
// caller's argument — confirm this is intended.
216 void Plumbing::Update(float learning_rate, float momentum, float adam_beta,
217  int num_samples) {
218  for (int i = 0; i < stack_.size(); ++i) {
220  if (i < learning_rates_.size())
221  learning_rate = learning_rates_[i];
222  else
223  learning_rates_.push_back(learning_rate);
224  }
225  if (stack_[i]->IsTraining()) {
226  stack_[i]->Update(learning_rate, momentum, adam_beta, num_samples);
227  }
228  }
229 }
230 
231 // Sums the products of weight updates in *this and other, splitting into
232 // positive (same direction) in *same and negative (different direction) in
233 // *changed.
234 void Plumbing::CountAlternators(const Network& other, double* same,
235  double* changed) const {
236  ASSERT_HOST(other.type() == type_);
237  const auto* plumbing = static_cast<const Plumbing*>(&other);
238  ASSERT_HOST(plumbing->stack_.size() == stack_.size());
239  for (int i = 0; i < stack_.size(); ++i)
240  stack_[i]->CountAlternators(*plumbing->stack_[i], same, changed);
241 }
242 
243 } // namespace tesseract.
tesseract::Plumbing::IsPlumbingType
bool IsPlumbingType() const override
Definition: plumbing.h:44
tesseract::NT_PARALLEL
Definition: network.h:49
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition: strngs.cpp:370
tesseract::Plumbing::learning_rates_
GenericVector< float > learning_rates_
Definition: plumbing.h:139
tesseract::Network::SetEnableTraining
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:110
tesseract::Plumbing::AddToStack
virtual void AddToStack(Network *network)
Definition: plumbing.cpp:82
ASSERT_HOST
#define ASSERT_HOST(x)
Definition: errcode.h:87
tesseract::Plumbing::EnumerateLayers
void EnumerateLayers(const STRING *prefix, GenericVector< STRING > *layers) const
Definition: plumbing.cpp:139
tesseract::Plumbing::Plumbing
Plumbing(const STRING &name)
Definition: plumbing.cpp:25
STRING
Definition: strngs.h:45
GenericVector::Serialize
bool Serialize(FILE *fp) const
Definition: genericvector.h:929
tesseract::Network::type
NetworkType type() const
Definition: network.h:112
tesseract::Network::needs_to_backprop_
bool needs_to_backprop_
Definition: network.h:295
tesseract::Plumbing::CountAlternators
void CountAlternators(const Network &other, double *same, double *changed) const override
Definition: plumbing.cpp:234
tesseract::Plumbing::RemapOutputs
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
Definition: plumbing.cpp:59
tesseract::Plumbing::CacheXScaleFactor
void CacheXScaleFactor(int factor) override
Definition: plumbing.cpp:126
tesseract::Plumbing::stack_
PointerVector< Network > stack_
Definition: plumbing.h:136
tesseract::Network::IsTraining
bool IsTraining() const
Definition: network.h:115
tesseract::Plumbing::InitWeights
int InitWeights(float range, TRand *randomizer) override
Definition: plumbing.cpp:50
tesseract::Plumbing::SetRandomizer
void SetRandomizer(TRand *randomizer) override
Definition: plumbing.cpp:76
GenericVector::push_back
int push_back(T object)
Definition: genericvector.h:799
tesseract::Plumbing::SetEnableTraining
void SetEnableTraining(TrainingState state) override
Definition: plumbing.cpp:31
tesseract::Network::CreateFromFile
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:187
tesseract::NT_SERIES
Definition: network.h:54
tesseract::Network::type_
NetworkType type_
Definition: network.h:293
tesseract::TFile::DeSerialize
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:117
GenericVector::DeSerialize
bool DeSerialize(bool swap, FILE *fp)
Definition: genericvector.h:954
tesseract::TFile::Serialize
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:161
tesseract::TFile
Definition: serialis.h:75
tesseract::Plumbing::ConvertToInt
void ConvertToInt() override
Definition: plumbing.cpp:68
tesseract::Plumbing::LayerLearningRatePtr
float * LayerLearningRatePtr(const char *id) const
Definition: plumbing.cpp:168
tesseract::Plumbing::GetLayer
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:155
tesseract
Definition: baseapi.h:65
tesseract::Network::SetNetworkFlags
virtual void SetNetworkFlags(uint32_t flags)
Definition: network.cpp:124
tesseract::Plumbing::Update
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
Definition: plumbing.cpp:216
tesseract::Network::NumOutputs
int NumOutputs() const
Definition: network.h:123
GenericVector< STRING >
tesseract::Plumbing::SetupNeedsBackprop
bool SetupNeedsBackprop(bool needs_backprop) override
Definition: plumbing.cpp:100
tesseract::Plumbing::DebugWeights
void DebugWeights() override
Definition: plumbing.cpp:133
tesseract::Network
Definition: network.h:105
tesseract::Network::num_weights_
int32_t num_weights_
Definition: network.h:299
tesseract::Plumbing::DeSerialize
bool DeSerialize(TFile *fp) override
Definition: plumbing.cpp:197
tesseract::TrainingState
TrainingState
Definition: network.h:92
tesseract::Plumbing::XScaleFactor
int XScaleFactor() const override
Definition: plumbing.cpp:120
tesseract::NF_LAYER_SPECIFIC_LR
Definition: network.h:87
tesseract::Network::Serialize
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
tesseract::Network::NumInputs
int NumInputs() const
Definition: network.h:120
tesseract::Network::no_
int32_t no_
Definition: network.h:298
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::TRand
Definition: helpers.h:50
tesseract::Plumbing::SetNetworkFlags
void SetNetworkFlags(uint32_t flags) override
Definition: plumbing.cpp:39
tesseract::Plumbing::Serialize
bool Serialize(TFile *fp) const override
Definition: plumbing.cpp:182
plumbing.h
tesseract::Network::network_flags_
int32_t network_flags_
Definition: network.h:296