tesseract  4.0.0-1-g2a2b
input.cpp
Go to the documentation of this file.
1 // File: input.cpp
3 // Description: Input layer class for neural network implementations.
4 // Author: Ray Smith
5 // Created: Thu Mar 13 09:10:34 PDT 2014
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #include "input.h"
20 
21 #include "allheaders.h"
22 #include "imagedata.h"
23 #include "pageres.h"
24 #include "scrollview.h"
25 
26 namespace tesseract {
27 
// Upper bound on the image height handed to ImageData::PreScale for inputs
// of variable height; taller images get scaled regardless.
constexpr int kMaxInputHeight = 48;
30 
31 Input::Input(const STRING& name, int ni, int no)
32  : Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}
33 Input::Input(const STRING& name, const StaticShape& shape)
34  : Network(NT_INPUT, name, shape.height(), shape.depth()),
35  shape_(shape),
36  cached_x_scale_(1) {
37  if (shape.height() == 1) ni_ = shape.depth();
38 }
39 
40 // Writes to the given file. Returns false in case of error.
41 bool Input::Serialize(TFile* fp) const {
42  return Network::Serialize(fp) && shape_.Serialize(fp);
43 }
44 
45 // Reads from the given file. Returns false in case of error.
47  return shape_.DeSerialize(fp);
48 }
49 
50 // Returns an integer reduction factor that the network applies to the
51 // time sequence. Assumes that any 2-d is already eliminated. Used for
52 // scaling bounding boxes of truth data.
53 int Input::XScaleFactor() const {
54  return 1;
55 }
56 
57 // Provides the (minimum) x scale factor to the network (of interest only to
58 // input units) so they can determine how to scale bounding boxes.
59 void Input::CacheXScaleFactor(int factor) {
60  cached_x_scale_ = factor;
61 }
62 
63 // Runs forward propagation of activations on the input line.
64 // See Network for a detailed discussion of the arguments.
65 void Input::Forward(bool debug, const NetworkIO& input,
66  const TransposedArray* input_transpose,
67  NetworkScratch* scratch, NetworkIO* output) {
68  *output = input;
69 }
70 
71 // Runs backward propagation of errors on the deltas line.
72 // See NetworkCpp for a detailed discussion of the arguments.
73 bool Input::Backward(bool debug, const NetworkIO& fwd_deltas,
74  NetworkScratch* scratch,
75  NetworkIO* back_deltas) {
76  tprintf("Input::Backward should not be called!!\n");
77  return false;
78 }
79 
80 // Creates and returns a Pix of appropriate size for the network from the
81 // image_data. If non-null, *image_scale returns the image scale factor used.
82 // Returns nullptr on error.
83 /* static */
84 Pix* Input::PrepareLSTMInputs(const ImageData& image_data,
85  const Network* network, int min_width,
86  TRand* randomizer, float* image_scale) {
87  // Note that NumInputs() is defined as input image height.
88  int target_height = network->NumInputs();
89  int width, height;
90  Pix* pix = image_data.PreScale(target_height, kMaxInputHeight, image_scale,
91  &width, &height, nullptr);
92  if (pix == nullptr) {
93  tprintf("Bad pix from ImageData!\n");
94  return nullptr;
95  }
96  if (width <= min_width || height < min_width) {
97  tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width,
98  height, min_width);
99  pixDestroy(&pix);
100  return nullptr;
101  }
102  return pix;
103 }
104 
105 // Converts the given pix to a NetworkIO of height and depth appropriate to the
106 // given StaticShape:
107 // If depth == 3, convert to 24 bit color, otherwise normalized grey.
108 // Scale to target height, if the shape's height is > 1, or its depth if the
109 // height == 1. If height == 0 then no scaling.
110 // NOTE: It isn't safe for multiple threads to call this on the same pix.
111 /* static */
112 void Input::PreparePixInput(const StaticShape& shape, const Pix* pix,
113  TRand* randomizer, NetworkIO* input) {
114  bool color = shape.depth() == 3;
115  Pix* var_pix = const_cast<Pix*>(pix);
116  int depth = pixGetDepth(var_pix);
117  Pix* normed_pix = nullptr;
118  // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
119  // colormap, so we just have to deal with depth conversion here.
120  if (color) {
121  // Force RGB.
122  if (depth == 32)
123  normed_pix = pixClone(var_pix);
124  else
125  normed_pix = pixConvertTo32(var_pix);
126  } else {
127  // Convert non-8-bit images to 8 bit.
128  if (depth == 8)
129  normed_pix = pixClone(var_pix);
130  else
131  normed_pix = pixConvertTo8(var_pix, false);
132  }
133  int height = pixGetHeight(normed_pix);
134  int target_height = shape.height();
135  if (target_height == 1) target_height = shape.depth();
136  if (target_height != 0 && target_height != height) {
137  // Get the scaled image.
138  float im_factor = static_cast<float>(target_height) / height;
139  Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
140  pixDestroy(&normed_pix);
141  normed_pix = scaled_pix;
142  }
143  input->FromPix(shape, normed_pix, randomizer);
144  pixDestroy(&normed_pix);
145 }
146 
147 } // namespace tesseract.
bool Serialize(TFile *fp) const override
Definition: input.cpp:41
Input(const STRING &name, int ni, int no)
Definition: input.cpp:31
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: input.cpp:65
Pix * PreScale(int target_height, int max_height, float *scale_factor, int *scaled_width, int *scaled_height, GenericVector< TBOX > *boxes) const
Definition: imagedata.cpp:226
int NumInputs() const
Definition: network.h:120
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
bool Serialize(TFile *fp) const
Definition: static_shape.h:76
void CacheXScaleFactor(int factor) override
Definition: input.cpp:59
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:37
const int kMaxInputHeight
Definition: input.cpp:29
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: input.cpp:73
int XScaleFactor() const override
Definition: input.cpp:53
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:112
Definition: strngs.h:45
void FromPix(const StaticShape &shape, const Pix *pix, TRand *randomizer)
Definition: networkio.cpp:166
bool DeSerialize(TFile *fp)
Definition: static_shape.h:64
bool DeSerialize(TFile *fp) override
Definition: input.cpp:46
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:84