tesseract  5.0.0-alpha-619-ge9db
input.cpp
Go to the documentation of this file.
1 // File: input.cpp
3 // Description: Input layer class for neural network implementations.
4 // Author: Ray Smith
5 //
6 // (C) Copyright 2014, Google Inc.
7 // Licensed under the Apache License, Version 2.0 (the "License");
8 // you may not use this file except in compliance with the License.
9 // You may obtain a copy of the License at
10 // http://www.apache.org/licenses/LICENSE-2.0
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
17 
18 #include "input.h"
19 
20 #include "allheaders.h"
21 #include "imagedata.h"
22 #include "pageres.h"
23 #include "scrollview.h"
24 
25 namespace tesseract {
26 
// Height limit for variable-height inputs: images taller than this get scaled
// anyway. Passed to ImageData::PreScale as its max_height argument.
constexpr int kMaxInputHeight = 48;
29 
30 Input::Input(const STRING& name, int ni, int no)
31  : Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}
32 Input::Input(const STRING& name, const StaticShape& shape)
33  : Network(NT_INPUT, name, shape.height(), shape.depth()),
34  shape_(shape),
35  cached_x_scale_(1) {
36  if (shape.height() == 1) ni_ = shape.depth();
37 }
38 
39 // Writes to the given file. Returns false in case of error.
40 bool Input::Serialize(TFile* fp) const {
41  return Network::Serialize(fp) && shape_.Serialize(fp);
42 }
43 
44 // Reads from the given file. Returns false in case of error.
46  return shape_.DeSerialize(fp);
47 }
48 
49 // Returns an integer reduction factor that the network applies to the
50 // time sequence. Assumes that any 2-d is already eliminated. Used for
51 // scaling bounding boxes of truth data.
52 int Input::XScaleFactor() const {
53  return 1;
54 }
55 
56 // Provides the (minimum) x scale factor to the network (of interest only to
57 // input units) so they can determine how to scale bounding boxes.
58 void Input::CacheXScaleFactor(int factor) {
59  cached_x_scale_ = factor;
60 }
61 
62 // Runs forward propagation of activations on the input line.
63 // See Network for a detailed discussion of the arguments.
64 void Input::Forward(bool debug, const NetworkIO& input,
65  const TransposedArray* input_transpose,
66  NetworkScratch* scratch, NetworkIO* output) {
67  *output = input;
68 }
69 
70 // Runs backward propagation of errors on the deltas line.
71 // See NetworkCpp for a detailed discussion of the arguments.
72 bool Input::Backward(bool debug, const NetworkIO& fwd_deltas,
73  NetworkScratch* scratch,
74  NetworkIO* back_deltas) {
75  tprintf("Input::Backward should not be called!!\n");
76  return false;
77 }
78 
79 // Creates and returns a Pix of appropriate size for the network from the
80 // image_data. If non-null, *image_scale returns the image scale factor used.
81 // Returns nullptr on error.
82 /* static */
83 Pix* Input::PrepareLSTMInputs(const ImageData& image_data,
84  const Network* network, int min_width,
85  TRand* randomizer, float* image_scale) {
86  // Note that NumInputs() is defined as input image height.
87  int target_height = network->NumInputs();
88  int width, height;
89  Pix* pix = image_data.PreScale(target_height, kMaxInputHeight, image_scale,
90  &width, &height, nullptr);
91  if (pix == nullptr) {
92  tprintf("Bad pix from ImageData!\n");
93  return nullptr;
94  }
95  if (width < min_width || height < min_width) {
96  tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width,
97  height, min_width);
98  pixDestroy(&pix);
99  return nullptr;
100  }
101  return pix;
102 }
103 
104 // Converts the given pix to a NetworkIO of height and depth appropriate to the
105 // given StaticShape:
106 // If depth == 3, convert to 24 bit color, otherwise normalized grey.
107 // Scale to target height, if the shape's height is > 1, or its depth if the
108 // height == 1. If height == 0 then no scaling.
109 // NOTE: It isn't safe for multiple threads to call this on the same pix.
110 /* static */
111 void Input::PreparePixInput(const StaticShape& shape, const Pix* pix,
112  TRand* randomizer, NetworkIO* input) {
113  bool color = shape.depth() == 3;
114  Pix* var_pix = const_cast<Pix*>(pix);
115  int depth = pixGetDepth(var_pix);
116  Pix* normed_pix = nullptr;
117  // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
118  // colormap, so we just have to deal with depth conversion here.
119  if (color) {
120  // Force RGB.
121  if (depth == 32)
122  normed_pix = pixClone(var_pix);
123  else
124  normed_pix = pixConvertTo32(var_pix);
125  } else {
126  // Convert non-8-bit images to 8 bit.
127  if (depth == 8)
128  normed_pix = pixClone(var_pix);
129  else
130  normed_pix = pixConvertTo8(var_pix, false);
131  }
132  int height = pixGetHeight(normed_pix);
133  int target_height = shape.height();
134  if (target_height == 1) target_height = shape.depth();
135  if (target_height != 0 && target_height != height) {
136  // Get the scaled image.
137  float im_factor = static_cast<float>(target_height) / height;
138  Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
139  pixDestroy(&normed_pix);
140  normed_pix = scaled_pix;
141  }
142  input->FromPix(shape, normed_pix, randomizer);
143  pixDestroy(&normed_pix);
144 }
145 
146 } // namespace tesseract.
tesseract::StaticShape
Definition: static_shape.h:38
pageres.h
tesseract::ImageData::PreScale
Pix * PreScale(int target_height, int max_height, float *scale_factor, int *scaled_width, int *scaled_height, GenericVector< TBOX > *boxes) const
Definition: imagedata.cpp:224
STRING
Definition: strngs.h:45
tesseract::Input::DeSerialize
bool DeSerialize(TFile *fp) override
Definition: input.cpp:45
tesseract::NetworkScratch
Definition: networkscratch.h:34
tesseract::Input::PreparePixInput
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:111
tesseract::ImageData
Definition: imagedata.h:104
tesseract::Input::PrepareLSTMInputs
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:83
tesseract::Input::Forward
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
Definition: input.cpp:64
tesseract::StaticShape::DeSerialize
bool DeSerialize(TFile *fp)
Definition: static_shape.h:64
tesseract::kMaxInputHeight
const int kMaxInputHeight
Definition: input.cpp:28
tesseract::Input::XScaleFactor
int XScaleFactor() const override
Definition: input.cpp:52
tesseract::StaticShape::depth
int depth() const
Definition: static_shape.h:48
tesseract::TFile
Definition: serialis.h:75
tesseract::NetworkIO
Definition: networkio.h:39
tesseract::NetworkIO::FromPix
void FromPix(const StaticShape &shape, const Pix *pix, TRand *randomizer)
Definition: networkio.cpp:161
tesseract
Definition: baseapi.h:65
tesseract::StaticShape::Serialize
bool Serialize(TFile *fp) const
Definition: static_shape.h:76
tesseract::NT_INPUT
Definition: network.h:45
tesseract::StaticShape::height
int height() const
Definition: static_shape.h:44
tesseract::Network
Definition: network.h:105
tesseract::Input::Input
Input(const STRING &name, int ni, int no)
Definition: input.cpp:30
imagedata.h
tesseract::Input::Backward
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
Definition: input.cpp:72
tesseract::TransposedArray
Definition: weightmatrix.h:32
tesseract::Input::CacheXScaleFactor
void CacheXScaleFactor(int factor) override
Definition: input.cpp:58
tesseract::Input::Serialize
bool Serialize(TFile *fp) const override
Definition: input.cpp:40
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::Network::Serialize
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
tesseract::Network::NumInputs
int NumInputs() const
Definition: network.h:120
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
tesseract::TRand
Definition: helpers.h:50
scrollview.h
input.h