#include <network.h>
|
| | Network () |
| |
| | Network (NetworkType type, const STRING &name, int ni, int no) |
| |
| virtual | ~Network ()=default |
| |
| NetworkType | type () const |
| |
| bool | IsTraining () const |
| |
| bool | needs_to_backprop () const |
| |
| int | num_weights () const |
| |
| int | NumInputs () const |
| |
| int | NumOutputs () const |
| |
| virtual StaticShape | InputShape () const |
| |
| virtual StaticShape | OutputShape (const StaticShape &input_shape) const |
| |
| const STRING & | name () const |
| |
| virtual STRING | spec () const |
| |
| bool | TestFlag (NetworkFlags flag) const |
| |
| virtual bool | IsPlumbingType () const |
| |
| virtual void | SetEnableTraining (TrainingState state) |
| |
| virtual void | SetNetworkFlags (uint32_t flags) |
| |
| virtual int | InitWeights (float range, TRand *randomizer) |
| |
| virtual int | RemapOutputs (int old_no, const std::vector< int > &code_map) |
| |
| virtual void | ConvertToInt () |
| |
| virtual void | SetRandomizer (TRand *randomizer) |
| |
| virtual bool | SetupNeedsBackprop (bool needs_backprop) |
| |
| virtual int | XScaleFactor () const |
| |
| virtual void | CacheXScaleFactor (int factor) |
| |
| virtual void | DebugWeights ()=0 |
| |
| virtual bool | Serialize (TFile *fp) const |
| |
| virtual bool | DeSerialize (TFile *fp)=0 |
| |
| virtual void | Update (float learning_rate, float momentum, float adam_beta, int num_samples) |
| |
| virtual void | CountAlternators (const Network &other, double *same, double *changed) const |
| |
| virtual void | Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0 |
| |
| virtual bool | Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0 |
| |
| void | DisplayForward (const NetworkIO &matrix) |
| |
| void | DisplayBackward (const NetworkIO &matrix) |
| |
Definition at line 105 of file network.h.
◆ Network() [1/2]
tesseract::Network::Network()
◆ Network() [2/2]
tesseract::Network::Network(NetworkType type, const STRING &name, int ni, int no)
◆ ~Network()
virtual tesseract::Network::~Network() = default
[virtual, default]
◆ Backward()
◆ CacheXScaleFactor()
virtual void tesseract::Network::CacheXScaleFactor(int factor)
[inline, virtual]
◆ ClearWindow()
void tesseract::Network::ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
[static]
Definition at line 312 of file network.cpp.
314 if (*window == nullptr) {
315 int min_size = std::min(width, height);
317 if (min_size < 1) min_size = 1;
325 *window = new ScrollView(window_name, 80, 100, width, height, width, height,
327 tprintf("Created window %s of size %d, %d\n", window_name, width, height);
◆ ConvertToInt()
virtual void tesseract::Network::ConvertToInt()
[inline, virtual]
◆ CountAlternators()
virtual void tesseract::Network::CountAlternators(const Network &other, double *same, double *changed) const
[inline, virtual]
◆ CreateFromFile()
Definition at line 187 of file network.cpp.
191 int32_t network_flags;
198 type = getNetworkType(fp);
199 if (!fp->DeSerialize(&data)) return nullptr;
201 if (!fp->DeSerialize(&data)) return nullptr;
203 if (!fp->DeSerialize(&network_flags)) return nullptr;
204 if (!fp->DeSerialize(&ni)) return nullptr;
205 if (!fp->DeSerialize(&no)) return nullptr;
206 if (!fp->DeSerialize(&num_weights)) return nullptr;
211 network = new Convolve(name, ni, 0, 0);
214 network = new Input(name, ni, no);
221 new LSTM(name, ni, no, no, false, type);
224 network = new Maxpool(name, ni, 0, 0);
235 network = new Reconfig(name, ni, 0, 0);
244 network = new Series(name);
247 #ifdef INCLUDE_TENSORFLOW
248 network = new TFNetwork(name);
250 tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
262 network = new FullyConnected(name, ni, no, type);
268 network->training_ = training;
270 network->network_flags_ = network_flags;
272 if (!network->DeSerialize(fp)) {
◆ DebugWeights()
virtual void tesseract::Network::DebugWeights()
[pure virtual]
◆ DeSerialize()
virtual bool tesseract::Network::DeSerialize(TFile *fp)
[pure virtual]
◆ DisplayBackward()
void tesseract::Network::DisplayBackward(const NetworkIO &matrix)
Definition at line 299 of file network.cpp.
300 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
301 Pix* image = matrix.ToPix();
307 #endif // GRAPHICS_DISABLED
◆ DisplayForward()
void tesseract::Network::DisplayForward(const NetworkIO &matrix)
Definition at line 288 of file network.cpp.
289 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
290 Pix* image = matrix.ToPix();
295 #endif // GRAPHICS_DISABLED
◆ DisplayImage()
int tesseract::Network::DisplayImage(Pix *pix, ScrollView *window)
[static]
Definition at line 335 of file network.cpp.
336 int height = pixGetHeight(pix);
337 window->Image(pix, 0, 0);
◆ Forward()
◆ InitWeights()
int tesseract::Network::InitWeights(float range, TRand *randomizer)
[virtual]
◆ InputShape()
virtual StaticShape tesseract::Network::InputShape() const
[inline, virtual]
◆ IsPlumbingType()
virtual bool tesseract::Network::IsPlumbingType() const
[inline, virtual]
◆ IsTraining()
bool tesseract::Network::IsTraining() const
[inline]
◆ name()
const STRING& tesseract::Network::name() const
[inline]
◆ needs_to_backprop()
bool tesseract::Network::needs_to_backprop() const
[inline]
◆ num_weights()
int tesseract::Network::num_weights() const
[inline]
◆ NumInputs()
int tesseract::Network::NumInputs() const
[inline]
◆ NumOutputs()
int tesseract::Network::NumOutputs() const
[inline]
◆ OutputShape()
◆ Random()
double tesseract::Network::Random(double range)
[protected]
◆ RemapOutputs()
virtual int tesseract::Network::RemapOutputs(int old_no, const std::vector<int> &code_map)
[inline, virtual]
◆ Serialize()
bool tesseract::Network::Serialize(TFile *fp) const
[virtual]
◆ SetEnableTraining()
void tesseract::Network::SetEnableTraining(TrainingState state)
[virtual]
◆ SetNetworkFlags()
void tesseract::Network::SetNetworkFlags(uint32_t flags)
[virtual]
◆ SetRandomizer()
void tesseract::Network::SetRandomizer(TRand *randomizer)
[virtual]
◆ SetupNeedsBackprop()
bool tesseract::Network::SetupNeedsBackprop(bool needs_backprop)
[virtual]
◆ spec()
virtual STRING tesseract::Network::spec() const
[inline, virtual]
◆ TestFlag()
bool tesseract::Network::TestFlag(NetworkFlags flag) const
[inline]
◆ type()
◆ Update()
virtual void tesseract::Network::Update(float learning_rate, float momentum, float adam_beta, int num_samples)
[inline, virtual]
◆ XScaleFactor()
virtual int tesseract::Network::XScaleFactor() const
[inline, virtual]
◆ backward_win_
◆ forward_win_
◆ name_
STRING tesseract::Network::name_
[protected]
◆ needs_to_backprop_
bool tesseract::Network::needs_to_backprop_
[protected]
◆ network_flags_
int32_t tesseract::Network::network_flags_
[protected]
◆ ni_
int32_t tesseract::Network::ni_
[protected]
◆ no_
int32_t tesseract::Network::no_
[protected]
◆ num_weights_
int32_t tesseract::Network::num_weights_
[protected]
◆ randomizer_
TRand* tesseract::Network::randomizer_
[protected]
◆ training_
◆ type_
The documentation for this class was generated from the following files: