tesseract  5.0.0-alpha-619-ge9db
maxpool.cpp
Go to the documentation of this file.
1 // File: maxpool.cpp
3 // Description: Standard Max-Pooling layer.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2014, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #include "maxpool.h"
19 
20 namespace tesseract {
21 
22 Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale)
23  : Reconfig(name, ni, x_scale, y_scale) {
24  type_ = NT_MAXPOOL;
25  no_ = ni;
26 }
27 
28 // Reads from the given file. Returns false in case of error.
30  bool result = Reconfig::DeSerialize(fp);
31  no_ = ni_;
32  return result;
33 }
34 
35 // Runs forward propagation of activations on the input line.
36 // See NetworkCpp for a detailed discussion of the arguments.
37 void Maxpool::Forward(bool debug, const NetworkIO& input,
38  const TransposedArray* input_transpose,
39  NetworkScratch* scratch, NetworkIO* output) {
40  output->ResizeScaled(input, x_scale_, y_scale_, no_);
41  maxes_.ResizeNoInit(output->Width(), ni_);
42  back_map_ = input.stride_map();
43 
44  StrideMap::Index dest_index(output->stride_map());
45  do {
46  int out_t = dest_index.t();
47  StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
48  dest_index.index(FD_HEIGHT) * y_scale_,
49  dest_index.index(FD_WIDTH) * x_scale_);
50  // Find the max input out of x_scale_ groups of y_scale_ inputs.
51  // Do it independently for each input dimension.
52  int* max_line = maxes_[out_t];
53  int in_t = src_index.t();
54  output->CopyTimeStepFrom(out_t, input, in_t);
55  for (int i = 0; i < ni_; ++i) {
56  max_line[i] = in_t;
57  }
58  for (int x = 0; x < x_scale_; ++x) {
59  for (int y = 0; y < y_scale_; ++y) {
60  StrideMap::Index src_xy(src_index);
61  if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
62  output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line);
63  }
64  }
65  }
66  } while (dest_index.Increment());
67 }
68 
69 // Runs backward propagation of errors on the deltas line.
70 // See NetworkCpp for a detailed discussion of the arguments.
71 bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas,
72  NetworkScratch* scratch,
73  NetworkIO* back_deltas) {
74  back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
75  back_deltas->MaxpoolBackward(fwd_deltas, maxes_);
76  return true;
77 }
78 
79 
80 } // namespace tesseract.
tesseract::Reconfig::DeSerialize
bool DeSerialize(TFile *fp) override
Definition: reconfig.cpp:57
tesseract::NetworkIO::int_mode
bool int_mode() const
Definition: networkio.h:127
tesseract::NetworkIO::ResizeScaled
void ResizeScaled(const NetworkIO &src, int x_scale, int y_scale, int num_features)
Definition: networkio.cpp:62
tesseract::StrideMap::Index
Definition: stridemap.h:44
STRING
Definition: strngs.h:45
tesseract::NetworkIO::Width
int Width() const
Definition: networkio.h:107
tesseract::Reconfig
Definition: reconfig.h:32
tesseract::NetworkScratch
Definition: networkscratch.h:34
tesseract::NetworkIO::stride_map
const StrideMap & stride_map() const
Definition: networkio.h:133
tesseract::Maxpool::Forward
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: maxpool.cpp:37
tesseract::Maxpool::Maxpool
Maxpool(const STRING &name, int ni, int x_scale, int y_scale)
Definition: maxpool.cpp:22
tesseract::Reconfig::y_scale_
int32_t y_scale_
Definition: reconfig.h:83
maxpool.h
tesseract::Network::type_
NetworkType type_
Definition: network.h:293
tesseract::StrideMap::Index::AddOffset
bool AddOffset(int offset, FlexDimensions dimension)
Definition: stridemap.cpp:62
tesseract::StrideMap::Index::t
int t() const
Definition: stridemap.h:57
tesseract::FD_WIDTH
Definition: stridemap.h:35
GENERIC_2D_ARRAY::ResizeNoInit
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:90
tesseract::TFile
Definition: serialis.h:75
tesseract::NetworkIO::MaxpoolBackward
void MaxpoolBackward(const NetworkIO &fwd, const GENERIC_2D_ARRAY< int > &maxes)
Definition: networkio.cpp:695
tesseract::NetworkIO::CopyTimeStepFrom
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:383
tesseract::NetworkIO
Definition: networkio.h:39
tesseract
Definition: baseapi.h:65
tesseract::Reconfig::x_scale_
int32_t x_scale_
Definition: reconfig.h:82
tesseract::NetworkIO::ResizeToMap
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
Definition: networkio.cpp:46
tesseract::FD_HEIGHT
Definition: stridemap.h:34
tesseract::TransposedArray
Definition: weightmatrix.h:32
tesseract::NetworkIO::MaxpoolTimeStep
void MaxpoolTimeStep(int dest_t, const NetworkIO &src, int src_t, int *max_line)
Definition: networkio.cpp:668
tesseract::Reconfig::back_map_
StrideMap back_map_
Definition: reconfig.h:80
tesseract::FD_BATCH
Definition: stridemap.h:33
tesseract::Maxpool::DeSerialize
bool DeSerialize(TFile *fp) override
Definition: maxpool.cpp:29
tesseract::Network::no_
int32_t no_
Definition: network.h:298
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
tesseract::NT_MAXPOOL
Definition: network.h:48
tesseract::Maxpool::Backward
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: maxpool.cpp:71