tesseract  4.0.0-1-g2a2b
tesseract::FullyConnected Class Reference

#include <fullyconnected.h>

Inheritance diagram for tesseract::FullyConnected:
tesseract::Network

Public Member Functions

 FullyConnected (const STRING &name, int ni, int no, NetworkType type)
 
virtual ~FullyConnected ()=default
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
STRING spec () const override
 
void ChangeType (NetworkType type)
 
void SetEnableTraining (TrainingState state) override
 
int InitWeights (float range, TRand *randomizer) override
 
int RemapOutputs (int old_no, const std::vector< int > &code_map) override
 
void ConvertToInt () override
 
void DebugWeights () override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
void SetupForward (const NetworkIO &input, const TransposedArray *input_transpose)
 
void ForwardTimeStep (int t, double *output_line)
 
void ForwardTimeStep (const double *d_input, int t, double *output_line)
 
void ForwardTimeStep (const int8_t *i_input, int t, double *output_line)
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
void BackwardTimeStep (const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
 
void FinishBackward (const TransposedArray &errors_t)
 
void Update (float learning_rate, float momentum, float adam_beta, int num_samples) override
 
void CountAlternators (const Network &other, double *same, double *changed) const override
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Protected Attributes

WeightMatrix weights_
 
TransposedArray source_t_
 
const TransposedArray * external_source_
 
NetworkIO acts_
 
bool int_mode_
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 28 of file fullyconnected.h.
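
FullyConnected implements a single fully-connected (dense) network layer; the NetworkType passed to the constructor selects the output non-linearity (NT_TANH, NT_LOGISTIC, NT_RELU, NT_LINEAR, NT_POSCLIP, NT_SYMCLIP, NT_SOFTMAX or NT_SOFTMAX_NO_CTC, as dispatched in ForwardTimeStep() and spec() below).

A minimal construction sketch follows. It is illustrative only: the name "fc_example", the sizes 96/48 and the weight range are invented, and it assumes the tesseract training headers are on the include path.

#include "fullyconnected.h"   // tesseract::FullyConnected

int main() {
  using namespace tesseract;
  // 96 inputs, 48 outputs, tanh non-linearity (all values illustrative).
  FullyConnected layer("fc_example", 96, 48, NT_TANH);
  // Weights must be initialized (or deserialized) before the layer is used.
  TRand randomizer;
  layer.InitWeights(0.1f, &randomizer);
  // spec() reports the layer in network-spec notation, "Ft48" for this layer.
  STRING spec = layer.spec();
  return 0;
}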

Constructor & Destructor Documentation

◆ FullyConnected()

tesseract::FullyConnected::FullyConnected ( const STRING &  name,
int  ni,
int  no,
NetworkType  type 
)

Definition at line 39 of file fullyconnected.cpp.

FullyConnected::FullyConnected(const STRING& name, int ni, int no,
                               NetworkType type)
  : Network(type, name, ni, no), external_source_(nullptr), int_mode_(false) {
}

◆ ~FullyConnected()

virtual tesseract::FullyConnected::~FullyConnected ( )
virtual default

Member Function Documentation

◆ Backward()

bool tesseract::FullyConnected::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 221 of file fullyconnected.cpp.

bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
                              NetworkScratch* scratch,
                              NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->Resize(fwd_deltas, ni_);
  GenericVector<NetworkScratch::FloatVec> errors;
  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  for (int i = 0; i < kNumThreads; ++i) errors[i].Init(no_, scratch);
  GenericVector<NetworkScratch::FloatVec> temp_backprops;
  if (needs_to_backprop_) {
    temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
    for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
  }
  int width = fwd_deltas.Width();
  NetworkScratch::GradientStore errors_t;
  errors_t.Init(no_, width, scratch);
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    int thread_id = 0;
#endif
    double* backprop = nullptr;
    if (needs_to_backprop_) backprop = temp_backprops[thread_id];
    double* curr_errors = errors[thread_id];
    BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
    if (backprop != nullptr) {
      back_deltas->WriteTimeStep(t, backprop);
    }
  }
  FinishBackward(*errors_t.get());
  if (needs_to_backprop_) {
    back_deltas->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
    tprintf("F Backprop:%s\n", name_.string());
    back_deltas->Print(10);
#endif
    return true;
  }
  return false;  // No point going further back.
}
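
Apart from the threading scaffolding, the per-timestep arithmetic in Backward()/BackwardTimeStep() is the standard dense-layer backward pass: the error signal is the upstream delta multiplied elementwise by the activation derivative (except for the softmax/linear types, where the upstream deltas already are the errors), and the delta passed to the layer below is W transposed times that error signal. The sketch below shows just this arithmetic with plain buffers and invented sizes and values; it is not the Tesseract API, and bias handling and gradient accumulation (see FinishBackward()) are omitted.

#include <cstdio>
#include <vector>

int main() {
  const int ni = 3, no = 2;
  // Row-major weight matrix W of shape no x ni (bias column omitted).
  std::vector<double> W = {0.1, -0.2, 0.3,
                           0.4,  0.0, -0.5};
  std::vector<double> fwd_deltas = {0.25, -0.75};  // upstream deltas, size no
  std::vector<double> act_prime  = {1.0, 0.0};     // f'(activation), size no

  // errors = f'(activation) (elementwise) * fwd_deltas
  std::vector<double> errors(no);
  for (int i = 0; i < no; ++i) errors[i] = act_prime[i] * fwd_deltas[i];

  // backprop = W^T * errors (what the lower layer receives)
  std::vector<double> backprop(ni, 0.0);
  for (int i = 0; i < no; ++i)
    for (int j = 0; j < ni; ++j) backprop[j] += W[i * ni + j] * errors[i];

  for (double v : backprop) std::printf("%g ", v);  // 0.025 -0.05 0.075
  std::printf("\n");
  return 0;
}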

◆ BackwardTimeStep()

void tesseract::FullyConnected::BackwardTimeStep ( const NetworkIO &  fwd_deltas,
int  t,
double *  curr_errors,
TransposedArray *  errors_t,
double *  backprop 
)

Definition at line 265 of file fullyconnected.cpp.

void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t,
                                      double* curr_errors,
                                      TransposedArray* errors_t,
                                      double* backprop) {
  if (type_ == NT_TANH)
    acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_LOGISTIC)
    acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_POSCLIP)
    acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_SYMCLIP)
    acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_RELU)
    acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
           type_ == NT_LINEAR)
    fwd_deltas.ReadTimeStep(t, curr_errors);  // fwd_deltas are the errors.
  else
    ASSERT_HOST("Invalid fully-connected type!" == nullptr);
  // Generate backprop only if needed by the lower layer.
  if (backprop != nullptr) weights_.VectorDotMatrix(curr_errors, backprop);
  errors_t->WriteStrided(t, curr_errors);
}

◆ ChangeType()

void tesseract::FullyConnected::ChangeType ( NetworkType  type)
inline

Definition at line 60 of file fullyconnected.h.

void ChangeType(NetworkType type) {
  type_ = type;
}

◆ ConvertToInt()

void tesseract::FullyConnected::ConvertToInt ( )
override virtual

Reimplemented from tesseract::Network.

Definition at line 96 of file fullyconnected.cpp.

void FullyConnected::ConvertToInt() {
  weights_.ConvertToInt();
}

◆ CountAlternators()

void tesseract::FullyConnected::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 306 of file fullyconnected.cpp.

void FullyConnected::CountAlternators(const Network& other, double* same,
                                      double* changed) const {
  ASSERT_HOST(other.type() == type_);
  const FullyConnected* fc = static_cast<const FullyConnected*>(&other);
  weights_.CountAlternators(fc->weights_, same, changed);
}

◆ DebugWeights()

void tesseract::FullyConnected::DebugWeights ( )
override virtual

Reimplemented from tesseract::Network.

Definition at line 101 of file fullyconnected.cpp.

void FullyConnected::DebugWeights() {
  weights_.Debug2D(name_.string());
}

◆ DeSerialize()

bool tesseract::FullyConnected::DeSerialize ( TFile *  fp)
override virtual

Reimplemented from tesseract::Network.

Definition at line 113 of file fullyconnected.cpp.

bool FullyConnected::DeSerialize(TFile* fp) {
  return weights_.DeSerialize(IsTraining(), fp);
}

◆ FinishBackward()

void tesseract::FullyConnected::FinishBackward ( const TransposedArray &  errors_t)

Definition at line 289 of file fullyconnected.cpp.

void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
  if (external_source_ == nullptr)
    weights_.SumOuterTransposed(errors_t, source_t_, true);
  else
    weights_.SumOuterTransposed(errors_t, *external_source_, true);
}
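
SumOuterTransposed effectively accumulates the weight gradient as the sum, over timesteps, of the outer product of the error vector with the forward input for that timestep. A standalone single-timestep illustration with plain buffers (not the TransposedArray/WeightMatrix layout; sizes and values are invented):

#include <cstdio>
#include <vector>

int main() {
  const int ni = 3, no = 2;
  std::vector<double> errors = {0.25, 0.0};      // size no, one timestep
  std::vector<double> source = {1.0, 2.0, 3.0};  // size ni, same timestep
  std::vector<double> dW(no * ni, 0.0);          // row-major gradient buffer
  // dW[i][j] += errors[i] * source[j]; a real pass sums this over all t.
  for (int i = 0; i < no; ++i)
    for (int j = 0; j < ni; ++j) dW[i * ni + j] += errors[i] * source[j];
  for (int i = 0; i < no; ++i) {
    for (int j = 0; j < ni; ++j) std::printf("%g ", dW[i * ni + j]);
    std::printf("\n");
  }
  return 0;
}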

◆ Forward()

void tesseract::FullyConnected::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 119 of file fullyconnected.cpp.

void FullyConnected::Forward(bool debug, const NetworkIO& input,
                             const TransposedArray* input_transpose,
                             NetworkScratch* scratch, NetworkIO* output) {
  int width = input.Width();
  if (type_ == NT_SOFTMAX)
    output->ResizeFloat(input, no_);
  else
    output->Resize(input, no_);
  SetupForward(input, input_transpose);
  GenericVector<NetworkScratch::FloatVec> temp_lines;
  temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  GenericVector<NetworkScratch::FloatVec> curr_input;
  curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  for (int i = 0; i < kNumThreads; ++i) {
    temp_lines[i].Init(no_, scratch);
    curr_input[i].Init(ni_, scratch);
  }
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = 0;
#endif
    double* temp_line = temp_lines[thread_id];
    if (input.int_mode()) {
      ForwardTimeStep(input.i(t), t, temp_line);
    } else {
      input.ReadTimeStep(t, curr_input[thread_id]);
      ForwardTimeStep(curr_input[thread_id], t, temp_line);
    }
    output->WriteTimeStep(t, temp_line);
    if (IsTraining() && type_ != NT_SOFTMAX) {
      acts_.CopyTimeStepFrom(t, *output, t);
    }
  }
  // Zero all the elements that are in the padding around images that allows
  // multiple different-sized images to exist in a single array.
  // acts_ is only used if this is not a softmax op.
  if (IsTraining() && type_ != NT_SOFTMAX) {
    acts_.ZeroInvalidElements();
  }
  output->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
  tprintf("F Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}
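
Ignoring the threading and the int/float input paths, each timestep computes output_t = f(W * x_t), where f is the non-linearity selected by the layer type (applied in ForwardTimeStep(int, double*) below) and the bias, if any, is handled inside weights_. The sketch below shows the equivalent arithmetic for one timestep with an explicit bias vector and plain buffers; it is illustrative only, not the WeightMatrix layout, and all sizes and values are invented.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const int ni = 3, no = 2;
  std::vector<double> W = {0.1, -0.2, 0.3,   // row-major, no x ni
                           0.4,  0.0, -0.5};
  std::vector<double> b = {0.05, -0.05};     // bias, size no
  std::vector<double> x = {1.0, 2.0, 3.0};   // one input timestep, size ni
  std::vector<double> y(no);
  for (int i = 0; i < no; ++i) {
    double sum = b[i];
    for (int j = 0; j < ni; ++j) sum += W[i * ni + j] * x[j];
    y[i] = std::tanh(sum);  // NT_TANH; other layer types use a different f
  }
  std::printf("%f %f\n", y[0], y[1]);
  return 0;
}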

◆ ForwardTimeStep() [1/3]

void tesseract::FullyConnected::ForwardTimeStep ( int  t,
double *  output_line 
)

Definition at line 185 of file fullyconnected.cpp.

void FullyConnected::ForwardTimeStep(int t, double* output_line) {
  if (type_ == NT_TANH) {
    FuncInplace<GFunc>(no_, output_line);
  } else if (type_ == NT_LOGISTIC) {
    FuncInplace<FFunc>(no_, output_line);
  } else if (type_ == NT_POSCLIP) {
    FuncInplace<ClipFFunc>(no_, output_line);
  } else if (type_ == NT_SYMCLIP) {
    FuncInplace<ClipGFunc>(no_, output_line);
  } else if (type_ == NT_RELU) {
    FuncInplace<Relu>(no_, output_line);
  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
    SoftmaxInPlace(no_, output_line);
  } else if (type_ != NT_LINEAR) {
    ASSERT_HOST("Invalid fully-connected type!" == nullptr);
  }
}

◆ ForwardTimeStep() [2/3]

void tesseract::FullyConnected::ForwardTimeStep ( const double *  d_input,
int  t,
double *  output_line 
)

Definition at line 203 of file fullyconnected.cpp.

void FullyConnected::ForwardTimeStep(const double* d_input, int t,
                                     double* output_line) {
  // input is copied to source_ line-by-line for cache coherency.
  if (IsTraining() && external_source_ == nullptr)
    source_t_.WriteStrided(t, d_input);
  weights_.MatrixDotVector(d_input, output_line);
  ForwardTimeStep(t, output_line);
}

◆ ForwardTimeStep() [3/3]

void tesseract::FullyConnected::ForwardTimeStep ( const int8_t *  i_input,
int  t,
double *  output_line 
)

Definition at line 212 of file fullyconnected.cpp.

void FullyConnected::ForwardTimeStep(const int8_t* i_input, int t,
                                     double* output_line) {
  // input is copied to source_ line-by-line for cache coherency.
  weights_.MatrixDotVector(i_input, output_line);
  ForwardTimeStep(t, output_line);
}

◆ InitWeights()

int tesseract::FullyConnected::InitWeights ( float  range,
TRand *  randomizer 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 77 of file fullyconnected.cpp.

int FullyConnected::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADAM),
                                           range, randomizer);
  return num_weights_;
}

◆ OutputShape()

StaticShape tesseract::FullyConnected::OutputShape ( const StaticShape &  input_shape) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 46 of file fullyconnected.cpp.

StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
  LossType loss_type = LT_NONE;
  if (type_ == NT_SOFTMAX)
    loss_type = LT_CTC;
  else if (type_ == NT_SOFTMAX_NO_CTC)
    loss_type = LT_SOFTMAX;
  else if (type_ == NT_LOGISTIC)
    loss_type = LT_LOGISTIC;
  StaticShape result(input_shape);
  result.set_depth(no_);
  result.set_loss_type(loss_type);
  return result;
}
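
In other words, the output shape is the input shape with its depth replaced by the number of outputs, plus a loss type derived from the layer type (LT_CTC for NT_SOFTMAX, LT_SOFTMAX for NT_SOFTMAX_NO_CTC, LT_LOGISTIC for NT_LOGISTIC, LT_NONE otherwise). A hedged usage sketch; layer and in_shape are placeholder variables:

// layer: a FullyConnected with, say, no_ == 111 and type NT_SOFTMAX
// in_shape: the StaticShape produced by the preceding layer
StaticShape out_shape = layer.OutputShape(in_shape);
// out_shape keeps in_shape's batch/height/width, with depth 111 and LT_CTC.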

◆ RemapOutputs()

int tesseract::FullyConnected::RemapOutputs ( int  old_no,
const std::vector< int > &  code_map 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 87 of file fullyconnected.cpp.

int FullyConnected::RemapOutputs(int old_no, const std::vector<int>& code_map) {
  if (type_ == NT_SOFTMAX && no_ == old_no) {
    num_weights_ = weights_.RemapOutputs(code_map);
    no_ = code_map.size();
  }
  return num_weights_;
}

◆ Serialize()

bool tesseract::FullyConnected::Serialize ( TFile *  fp) const
override virtual

Reimplemented from tesseract::Network.

Definition at line 106 of file fullyconnected.cpp.

bool FullyConnected::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (!weights_.Serialize(IsTraining(), fp)) return false;
  return true;
}

◆ SetEnableTraining()

void tesseract::FullyConnected::SetEnableTraining ( TrainingState  state)
override virtual

Reimplemented from tesseract::Network.

Definition at line 61 of file fullyconnected.cpp.

void FullyConnected::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED)
      weights_.InitBackward();
    training_ = state;
  }
}

◆ SetupForward()

void tesseract::FullyConnected::SetupForward ( const NetworkIO &  input,
const TransposedArray *  input_transpose 
)

Definition at line 173 of file fullyconnected.cpp.

void FullyConnected::SetupForward(const NetworkIO& input,
                                  const TransposedArray* input_transpose) {
  // Softmax output is always float, so save the input type.
  int_mode_ = input.int_mode();
  if (IsTraining()) {
    acts_.Resize(input, no_);
    // Source_ is a transposed copy of input. It isn't needed if provided.
    external_source_ = input_transpose;
    if (external_source_ == nullptr) source_t_.ResizeNoInit(ni_, input.Width());
  }
}

◆ spec()

STRING tesseract::FullyConnected::spec ( ) const
inline override virtual

Reimplemented from tesseract::Network.

Definition at line 37 of file fullyconnected.h.

STRING spec() const override {
  STRING spec;
  if (type_ == NT_TANH)
    spec.add_str_int("Ft", no_);
  else if (type_ == NT_LOGISTIC)
    spec.add_str_int("Fs", no_);
  else if (type_ == NT_RELU)
    spec.add_str_int("Fr", no_);
  else if (type_ == NT_LINEAR)
    spec.add_str_int("Fl", no_);
  else if (type_ == NT_POSCLIP)
    spec.add_str_int("Fp", no_);
  else if (type_ == NT_SYMCLIP)
    spec.add_str_int("Fs", no_);
  else if (type_ == NT_SOFTMAX)
    spec.add_str_int("Fc", no_);
  else
    spec.add_str_int("Fm", no_);
  return spec;
}
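
The returned string is the layer's spec-language shorthand: 'F', a letter for the non-linearity, then the output count, so an NT_RELU layer with 64 outputs yields "Fr64" and an NT_SOFTMAX layer with 111 outputs yields "Fc111" (note that NT_LOGISTIC and NT_SYMCLIP both map to "Fs" here). A one-line usage sketch with an illustrative layer:

// layer constructed as FullyConnected("fc", 96, 64, NT_RELU); values illustrative
STRING s = layer.spec();  // s holds "Fr64"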

◆ Update()

void tesseract::FullyConnected::Update ( float  learning_rate,
float  momentum,
float  adam_beta,
int  num_samples 
)
override virtual

Reimplemented from tesseract::Network.

Definition at line 298 of file fullyconnected.cpp.

void FullyConnected::Update(float learning_rate, float momentum,
                            float adam_beta, int num_samples) {
  weights_.Update(learning_rate, momentum, adam_beta, num_samples);
}

Member Data Documentation

◆ acts_

NetworkIO tesseract::FullyConnected::acts_
protected

Definition at line 126 of file fullyconnected.h.

◆ external_source_

const TransposedArray* tesseract::FullyConnected::external_source_
protected

Definition at line 124 of file fullyconnected.h.

◆ int_mode_

bool tesseract::FullyConnected::int_mode_
protected

Definition at line 129 of file fullyconnected.h.

◆ source_t_

TransposedArray tesseract::FullyConnected::source_t_
protected

Definition at line 121 of file fullyconnected.h.

◆ weights_

WeightMatrix tesseract::FullyConnected::weights_
protected

Definition at line 119 of file fullyconnected.h.


The documentation for this class was generated from the following files:
fullyconnected.h
fullyconnected.cpp