tesseract  5.0.0-alpha-619-ge9db
tesseract::Input Class Reference

#include <input.h>

Inheritance diagram for tesseract::Input:
tesseract::Network

Public Member Functions

 Input (const STRING &name, int ni, int no)
 
 Input (const STRING &name, const StaticShape &shape)
 
 ~Input () override=default
 
STRING spec () const override
 
StaticShape InputShape () const override
 
StaticShape OutputShape (const StaticShape &input_shape) const override
 
bool Serialize (TFile *fp) const override
 
bool DeSerialize (TFile *fp) override
 
int XScaleFactor () const override
 
void CacheXScaleFactor (int factor) override
 
void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
 
bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()=default
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetEnableTraining (TrainingState state)
 
virtual void SetNetworkFlags (uint32_t flags)
 
virtual int InitWeights (float range, TRand *randomizer)
 
virtual int RemapOutputs (int old_no, const std::vector< int > &code_map)
 
virtual void ConvertToInt ()
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual void Update (float learning_rate, float momentum, float adam_beta, int num_samples)
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Static Public Member Functions

static Pix * PrepareLSTMInputs (const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
 
static void PreparePixInput (const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
 
- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 

Additional Inherited Members

- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
int32_t network_flags_
 
int32_t ni_
 
int32_t no_
 
int32_t num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Detailed Description

Definition at line 27 of file input.h.

Constructor & Destructor Documentation

◆ Input() [1/2]

tesseract::Input::Input ( const STRING &  name,
int  ni,
int  no 
)

Definition at line 30 of file input.cpp.

31  : Network(NT_INPUT, name, ni, no), cached_x_scale_(1) {}

◆ Input() [2/2]

tesseract::Input::Input ( const STRING &  name,
const StaticShape &  shape 
)

Definition at line 32 of file input.cpp.

33  : Network(NT_INPUT, name, shape.height(), shape.depth()),
34  shape_(shape),
35  cached_x_scale_(1) {
36  if (shape.height() == 1) ni_ = shape.depth();
37 }

◆ ~Input()

tesseract::Input::~Input ( )
override, default

Member Function Documentation

◆ Backward()

bool tesseract::Input::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
override, virtual

Implements tesseract::Network.

Definition at line 72 of file input.cpp.

74  {
75  tprintf("Input::Backward should not be called!!\n");
76  return false;
77 }

◆ CacheXScaleFactor()

void tesseract::Input::CacheXScaleFactor ( int  factor)
override, virtual

Reimplemented from tesseract::Network.

Definition at line 58 of file input.cpp.

58  {
59  cached_x_scale_ = factor;
60 }

◆ DeSerialize()

bool tesseract::Input::DeSerialize ( TFile *  fp)
override, virtual

Implements tesseract::Network.

Definition at line 45 of file input.cpp.

45  {
46  return shape_.DeSerialize(fp);
47 }

◆ Forward()

void tesseract::Input::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
override, virtual

Implements tesseract::Network.

Definition at line 64 of file input.cpp.

66  {
67  *output = input;
68 }

◆ InputShape()

StaticShape tesseract::Input::InputShape ( ) const
inline, override, virtual

Reimplemented from tesseract::Network.

Definition at line 43 of file input.h.

43 { return shape_; }

◆ OutputShape()

StaticShape tesseract::Input::OutputShape ( const StaticShape &  input_shape) const
inline, override, virtual

Reimplemented from tesseract::Network.

Definition at line 46 of file input.h.

46  {
47  return shape_;
48  }

◆ PrepareLSTMInputs()

Pix * tesseract::Input::PrepareLSTMInputs ( const ImageData &  image_data,
const Network *  network,
int  min_width,
TRand *  randomizer,
float *  image_scale 
)
static

Definition at line 83 of file input.cpp.

85  {
86  // Note that NumInputs() is defined as input image height.
87  int target_height = network->NumInputs();
88  int width, height;
89  Pix* pix = image_data.PreScale(target_height, kMaxInputHeight, image_scale,
90  &width, &height, nullptr);
91  if (pix == nullptr) {
92  tprintf("Bad pix from ImageData!\n");
93  return nullptr;
94  }
95  if (width < min_width || height < min_width) {
96  tprintf("Image too small to scale!! (%dx%d vs min width of %d)\n", width,
97  height, min_width);
98  pixDestroy(&pix);
99  return nullptr;
100  }
101  return pix;
102 }

◆ PreparePixInput()

void tesseract::Input::PreparePixInput ( const StaticShape &  shape,
const Pix *  pix,
TRand *  randomizer,
NetworkIO *  input 
)
static

Definition at line 111 of file input.cpp.

112  {
113  bool color = shape.depth() == 3;
114  Pix* var_pix = const_cast<Pix*>(pix);
115  int depth = pixGetDepth(var_pix);
116  Pix* normed_pix = nullptr;
117  // On input to BaseAPI, an image is forced to be 1, 8 or 24 bit, without
118  // colormap, so we just have to deal with depth conversion here.
119  if (color) {
120  // Force RGB.
121  if (depth == 32)
122  normed_pix = pixClone(var_pix);
123  else
124  normed_pix = pixConvertTo32(var_pix);
125  } else {
126  // Convert non-8-bit images to 8 bit.
127  if (depth == 8)
128  normed_pix = pixClone(var_pix);
129  else
130  normed_pix = pixConvertTo8(var_pix, false);
131  }
132  int height = pixGetHeight(normed_pix);
133  int target_height = shape.height();
134  if (target_height == 1) target_height = shape.depth();
135  if (target_height != 0 && target_height != height) {
136  // Get the scaled image.
137  float im_factor = static_cast<float>(target_height) / height;
138  Pix* scaled_pix = pixScale(normed_pix, im_factor, im_factor);
139  pixDestroy(&normed_pix);
140  normed_pix = scaled_pix;
141  }
142  input->FromPix(shape, normed_pix, randomizer);
143  pixDestroy(&normed_pix);
144 }

◆ Serialize()

bool tesseract::Input::Serialize ( TFile *  fp) const
override, virtual

Reimplemented from tesseract::Network.

Definition at line 40 of file input.cpp.

40  {
41  return Network::Serialize(fp) && shape_.Serialize(fp);
42 }

◆ spec()

STRING tesseract::Input::spec ( ) const
inline, override, virtual

Reimplemented from tesseract::Network.

Definition at line 33 of file input.h.

33  {
34  STRING spec;
35  spec.add_str_int("", shape_.batch());
36  spec.add_str_int(",", shape_.height());
37  spec.add_str_int(",", shape_.width());
38  spec.add_str_int(",", shape_.depth());
39  return spec;
40  }

◆ XScaleFactor()

int tesseract::Input::XScaleFactor ( ) const
override, virtual

Reimplemented from tesseract::Network.

Definition at line 52 of file input.cpp.

52  {
53  return 1;
54 }

The documentation for this class was generated from the following files:
- input.h
- input.cpp
STRING::add_str_int
void add_str_int(const char *str, int number)
Definition: strngs.cpp:370
tesseract::StaticShape::batch
int batch() const
Definition: static_shape.h:42
STRING
Definition: strngs.h:45
tesseract::StaticShape::DeSerialize
bool DeSerialize(TFile *fp)
Definition: static_shape.h:64
tesseract::kMaxInputHeight
const int kMaxInputHeight
Definition: input.cpp:28
tesseract::StaticShape::depth
int depth() const
Definition: static_shape.h:48
tesseract::Input::spec
STRING spec() const override
Definition: input.h:33
tesseract::StaticShape::width
int width() const
Definition: static_shape.h:46
tesseract::StaticShape::Serialize
bool Serialize(TFile *fp) const
Definition: static_shape.h:76
tesseract::NT_INPUT
Definition: network.h:45
tesseract::StaticShape::height
int height() const
Definition: static_shape.h:44
tesseract::Network::name
const STRING & name() const
Definition: network.h:138
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
tesseract::Network::Serialize
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:151
tesseract::Network::ni_
int32_t ni_
Definition: network.h:297
tesseract::Network::Network
Network()
Definition: network.cpp:76