tesseract  4.0.0-1-g2a2b
static_shape.h
Go to the documentation of this file.
// File:        static_shape.h
// Description: Defines the size of the 4-d tensor input/output from a network.
// Author:      Ray Smith
// Created:     Fri Oct 14 09:07:31 PST 2016
//
// (C) Copyright 2016, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_STATIC_SHAPE_H_
20 #define TESSERACT_LSTM_STATIC_SHAPE_H_
21 
22 #include "serialis.h" // for TFile
23 #include "tprintf.h" // for tprintf
24 
25 namespace tesseract {
26 
// Enum describing the loss function to apply during training and/or the
// decoding method to apply at runtime.
//
// StaticShape::Serialize/DeSerialize write and read these values as int32_t,
// so the numeric values are part of the serialized format and are pinned
// explicitly: never renumber or reorder existing entries, only append.
enum LossType {
  LT_NONE = 0,      // Undefined.
  LT_CTC = 1,       // Softmax with standard CTC for training/decoding.
  LT_SOFTMAX = 2,   // Outputs sum to 1 in fixed positions.
  LT_LOGISTIC = 3,  // Logistic outputs with independent values.
};
35 
36 // Simple class to hold the tensor shape that is known at network build time
37 // and the LossType of the loss function.
38 class StaticShape {
39  public:
41  : batch_(0), height_(0), width_(0), depth_(0), loss_type_(LT_NONE) {}
42  int batch() const { return batch_; }
43  void set_batch(int value) { batch_ = value; }
44  int height() const { return height_; }
45  void set_height(int value) { height_ = value; }
46  int width() const { return width_; }
47  void set_width(int value) { width_ = value; }
48  int depth() const { return depth_; }
49  void set_depth(int value) { depth_ = value; }
50  LossType loss_type() const { return loss_type_; }
51  void set_loss_type(LossType value) { loss_type_ = value; }
52  void SetShape(int batch, int height, int width, int depth) {
53  batch_ = batch;
54  height_ = height;
55  width_ = width;
56  depth_ = depth;
57  }
58 
59  void Print() const {
60  tprintf("Batch=%d, Height=%d, Width=%d, Depth=%d, loss=%d\n", batch_,
61  height_, width_, depth_, loss_type_);
62  }
63 
64  bool DeSerialize(TFile *fp) {
65  int32_t tmp = LT_NONE;
66  bool result =
67  fp->DeSerialize(&batch_) &&
68  fp->DeSerialize(&height_) &&
69  fp->DeSerialize(&width_) &&
70  fp->DeSerialize(&depth_) &&
71  fp->DeSerialize(&tmp);
72  loss_type_ = static_cast<LossType>(tmp);
73  return result;
74  }
75 
76  bool Serialize(TFile *fp) const {
77  int32_t tmp = loss_type_;
78  return
79  fp->Serialize(&batch_) &&
80  fp->Serialize(&height_) &&
81  fp->Serialize(&width_) &&
82  fp->Serialize(&depth_) &&
83  fp->Serialize(&tmp);
84  }
85 
86  private:
87  // Size of the 4-D tensor input/output to a network. A value of zero is
88  // allowed for all except depth_ and means to be determined at runtime, and
89  // regarded as variable.
90  // Number of elements in a batch, or number of frames in a video stream.
91  int32_t batch_;
92  // Height of the image.
93  int32_t height_;
94  // Width of the image.
95  int32_t width_;
96  // Depth of the image. (Number of "nodes").
97  int32_t depth_;
98  // How to train/interpret the output.
99  LossType loss_type_;
100 };
101 
102 } // namespace tesseract
103 
104 #endif // TESSERACT_LSTM_STATIC_SHAPE_H_
bool DeSerialize(char *data, size_t count=1)
Definition: serialis.cpp:103
void set_batch(int value)
Definition: static_shape.h:43
LossType loss_type() const
Definition: static_shape.h:50
void set_width(int value)
Definition: static_shape.h:47
bool Serialize(TFile *fp) const
Definition: static_shape.h:76
bool Serialize(const char *data, size_t count=1)
Definition: serialis.cpp:147
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:37
void set_loss_type(LossType value)
Definition: static_shape.h:51
bool DeSerialize(TFile *fp)
Definition: static_shape.h:64
void set_depth(int value)
Definition: static_shape.h:49
void set_height(int value)
Definition: static_shape.h:45
void SetShape(int batch, int height, int width, int depth)
Definition: static_shape.h:52