tesseract  4.0.0-1-g2a2b
maxpool.cpp
Go to the documentation of this file.
1 // File: maxpool.cpp
3 // Description: Standard Max-Pooling layer.
4 // Author: Ray Smith
5 // Created: Tue Mar 18 16:28:18 PST 2014
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #include "maxpool.h"
20 #include "tprintf.h"
21 
22 namespace tesseract {
23 
24 Maxpool::Maxpool(const STRING& name, int ni, int x_scale, int y_scale)
25  : Reconfig(name, ni, x_scale, y_scale) {
26  type_ = NT_MAXPOOL;
27  no_ = ni;
28 }
29 
30 // Reads from the given file. Returns false in case of error.
32  bool result = Reconfig::DeSerialize(fp);
33  no_ = ni_;
34  return result;
35 }
36 
37 // Runs forward propagation of activations on the input line.
38 // See NetworkCpp for a detailed discussion of the arguments.
39 void Maxpool::Forward(bool debug, const NetworkIO& input,
40  const TransposedArray* input_transpose,
41  NetworkScratch* scratch, NetworkIO* output) {
42  output->ResizeScaled(input, x_scale_, y_scale_, no_);
43  maxes_.ResizeNoInit(output->Width(), ni_);
44  back_map_ = input.stride_map();
45 
46  StrideMap::Index dest_index(output->stride_map());
47  do {
48  int out_t = dest_index.t();
49  StrideMap::Index src_index(input.stride_map(), dest_index.index(FD_BATCH),
50  dest_index.index(FD_HEIGHT) * y_scale_,
51  dest_index.index(FD_WIDTH) * x_scale_);
52  // Find the max input out of x_scale_ groups of y_scale_ inputs.
53  // Do it independently for each input dimension.
54  int* max_line = maxes_[out_t];
55  int in_t = src_index.t();
56  output->CopyTimeStepFrom(out_t, input, in_t);
57  for (int i = 0; i < ni_; ++i) {
58  max_line[i] = in_t;
59  }
60  for (int x = 0; x < x_scale_; ++x) {
61  for (int y = 0; y < y_scale_; ++y) {
62  StrideMap::Index src_xy(src_index);
63  if (src_xy.AddOffset(x, FD_WIDTH) && src_xy.AddOffset(y, FD_HEIGHT)) {
64  output->MaxpoolTimeStep(out_t, input, src_xy.t(), max_line);
65  }
66  }
67  }
68  } while (dest_index.Increment());
69 }
70 
71 // Runs backward propagation of errors on the deltas line.
72 // See NetworkCpp for a detailed discussion of the arguments.
73 bool Maxpool::Backward(bool debug, const NetworkIO& fwd_deltas,
74  NetworkScratch* scratch,
75  NetworkIO* back_deltas) {
76  back_deltas->ResizeToMap(fwd_deltas.int_mode(), back_map_, ni_);
77  back_deltas->MaxpoolBackward(fwd_deltas, maxes_);
78  return true;
79 }
80 
81 
82 } // namespace tesseract.
int32_t y_scale_
Definition: reconfig.h:79
int32_t x_scale_
Definition: reconfig.h:78
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: maxpool.cpp:39
void MaxpoolBackward(const NetworkIO &fwd, const GENERIC_2D_ARRAY< int > &maxes)
Definition: networkio.cpp:700
bool AddOffset(int offset, FlexDimensions dimension)
Definition: stridemap.cpp:63
Maxpool(const STRING &name, int ni, int x_scale, int y_scale)
Definition: maxpool.cpp:24
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:388
void MaxpoolTimeStep(int dest_t, const NetworkIO &src, int src_t, int *max_line)
Definition: networkio.cpp:673
NetworkType type_
Definition: network.h:299
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:91
void ResizeScaled(const NetworkIO &src, int x_scale, int y_scale, int num_features)
Definition: networkio.cpp:67
bool DeSerialize(TFile *fp) override
Definition: reconfig.cpp:58
const StrideMap & stride_map() const
Definition: networkio.h:133
bool DeSerialize(TFile *fp) override
Definition: maxpool.cpp:31
STRING
Definition: strngs.h:45
bool int_mode() const
Definition: networkio.h:127
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
Definition: networkio.cpp:51
int Width() const
Definition: networkio.h:107
StrideMap back_map_
Definition: reconfig.h:76
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: maxpool.cpp:73