#include <network.h>
|
| Network () |
|
| Network (NetworkType type, const STRING &name, int ni, int no) |
|
virtual | ~Network ()=default |
|
NetworkType | type () const |
|
bool | IsTraining () const |
|
bool | needs_to_backprop () const |
|
int | num_weights () const |
|
int | NumInputs () const |
|
int | NumOutputs () const |
|
virtual StaticShape | InputShape () const |
|
virtual StaticShape | OutputShape (const StaticShape &input_shape) const |
|
const STRING & | name () const |
|
virtual STRING | spec () const |
|
bool | TestFlag (NetworkFlags flag) const |
|
virtual bool | IsPlumbingType () const |
|
virtual void | SetEnableTraining (TrainingState state) |
|
virtual void | SetNetworkFlags (uint32_t flags) |
|
virtual int | InitWeights (float range, TRand *randomizer) |
|
virtual int | RemapOutputs (int old_no, const std::vector< int > &code_map) |
|
virtual void | ConvertToInt () |
|
virtual void | SetRandomizer (TRand *randomizer) |
|
virtual bool | SetupNeedsBackprop (bool needs_backprop) |
|
virtual int | XScaleFactor () const |
|
virtual void | CacheXScaleFactor (int factor) |
|
virtual void | DebugWeights ()=0 |
|
virtual bool | Serialize (TFile *fp) const |
|
virtual bool | DeSerialize (TFile *fp)=0 |
|
virtual void | Update (float learning_rate, float momentum, float adam_beta, int num_samples) |
|
virtual void | CountAlternators (const Network &other, double *same, double *changed) const |
|
virtual void | Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)=0 |
|
virtual bool | Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0 |
|
void | DisplayForward (const NetworkIO &matrix) |
|
void | DisplayBackward (const NetworkIO &matrix) |
|
Definition at line 105 of file network.h.
◆ Network() [1/2]
tesseract::Network::Network |
( |
| ) |
|
◆ Network() [2/2]
tesseract::Network::Network |
( |
NetworkType |
type, |
|
|
const STRING & |
name, |
|
|
int |
ni, |
|
|
int |
no |
|
) |
| |
◆ ~Network()
virtual tesseract::Network::~Network |
( |
| ) |
|
|
virtual default
◆ Backward()
◆ CacheXScaleFactor()
virtual void tesseract::Network::CacheXScaleFactor |
( |
int |
factor | ) |
|
|
inline virtual
◆ ClearWindow()
void tesseract::Network::ClearWindow |
( |
bool |
tess_coords, |
|
|
const char * |
window_name, |
|
|
int |
width, |
|
|
int |
height, |
|
|
ScrollView ** |
window |
|
) |
| |
|
static |
Definition at line 312 of file network.cpp.
314 if (*window == nullptr) {
315 int min_size = std::min(width, height);
317 if (min_size < 1) min_size = 1;
325 *window = new ScrollView(window_name, 80, 100, width, height, width, height,
327 tprintf("Created window %s of size %d, %d\n", window_name, width, height);
◆ ConvertToInt()
virtual void tesseract::Network::ConvertToInt |
( |
| ) |
|
|
inline virtual
◆ CountAlternators()
virtual void tesseract::Network::CountAlternators |
( |
const Network & |
other, |
|
|
double * |
same, |
|
|
double * |
changed |
|
) |
| const |
|
inline virtual
◆ CreateFromFile()
Definition at line 187 of file network.cpp.
191 int32_t network_flags;
198 type = getNetworkType(fp);
199 if (!fp->DeSerialize(&data))
return nullptr;
201 if (!fp->DeSerialize(&data))
return nullptr;
203 if (!fp->DeSerialize(&network_flags))
return nullptr;
204 if (!fp->DeSerialize(&ni))
return nullptr;
205 if (!fp->DeSerialize(&no))
return nullptr;
206 if (!fp->DeSerialize(&num_weights))
return nullptr;
211 network = new Convolve(name, ni, 0, 0);
214 network = new Input(name, ni, no);
221 new LSTM(name, ni, no, no, false, type);
224 network = new Maxpool(name, ni, 0, 0);
235 network = new Reconfig(name, ni, 0, 0);
244 network = new Series(name);
247 #ifdef INCLUDE_TENSORFLOW
248 network = new TFNetwork(name);
250 tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
262 network = new FullyConnected(name, ni, no, type);
268 network->training_ = training;
270 network->network_flags_ = network_flags;
272 if (!network->DeSerialize(fp)) {
◆ DebugWeights()
virtual void tesseract::Network::DebugWeights |
( |
| ) |
|
|
pure virtual |
◆ DeSerialize()
virtual bool tesseract::Network::DeSerialize |
( |
TFile * |
fp | ) |
|
|
pure virtual |
◆ DisplayBackward()
void tesseract::Network::DisplayBackward |
( |
const NetworkIO & |
matrix | ) |
|
Definition at line 299 of file network.cpp.
300 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
301 Pix* image = matrix.ToPix();
307 #endif // GRAPHICS_DISABLED
◆ DisplayForward()
void tesseract::Network::DisplayForward |
( |
const NetworkIO & |
matrix | ) |
|
Definition at line 288 of file network.cpp.
289 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
290 Pix* image = matrix.ToPix();
295 #endif // GRAPHICS_DISABLED
◆ DisplayImage()
int tesseract::Network::DisplayImage |
( |
Pix * |
pix, |
|
|
ScrollView * |
window |
|
) |
| |
|
static |
Definition at line 335 of file network.cpp.
336 int height = pixGetHeight(pix);
337 window->Image(pix, 0, 0);
◆ Forward()
◆ InitWeights()
int tesseract::Network::InitWeights |
( |
float |
range, |
|
|
TRand * |
randomizer |
|
) |
| |
|
virtual |
◆ InputShape()
virtual StaticShape tesseract::Network::InputShape |
( |
| ) |
const |
|
inline virtual
◆ IsPlumbingType()
virtual bool tesseract::Network::IsPlumbingType |
( |
| ) |
const |
|
inline virtual
◆ IsTraining()
bool tesseract::Network::IsTraining |
( |
| ) |
const |
|
inline |
◆ name()
const STRING& tesseract::Network::name |
( |
| ) |
const |
|
inline |
◆ needs_to_backprop()
bool tesseract::Network::needs_to_backprop |
( |
| ) |
const |
|
inline |
◆ num_weights()
int tesseract::Network::num_weights |
( |
| ) |
const |
|
inline |
◆ NumInputs()
int tesseract::Network::NumInputs |
( |
| ) |
const |
|
inline |
◆ NumOutputs()
int tesseract::Network::NumOutputs |
( |
| ) |
const |
|
inline |
◆ OutputShape()
◆ Random()
double tesseract::Network::Random |
( |
double |
range | ) |
|
|
protected |
◆ RemapOutputs()
virtual int tesseract::Network::RemapOutputs |
( |
int |
old_no, |
|
|
const std::vector< int > & |
code_map |
|
) |
| |
|
inline virtual
◆ Serialize()
bool tesseract::Network::Serialize |
( |
TFile * |
fp | ) |
const |
|
virtual |
◆ SetEnableTraining()
void tesseract::Network::SetEnableTraining |
( |
TrainingState |
state | ) |
|
|
virtual |
◆ SetNetworkFlags()
void tesseract::Network::SetNetworkFlags |
( |
uint32_t |
flags | ) |
|
|
virtual |
◆ SetRandomizer()
void tesseract::Network::SetRandomizer |
( |
TRand * |
randomizer | ) |
|
|
virtual |
◆ SetupNeedsBackprop()
bool tesseract::Network::SetupNeedsBackprop |
( |
bool |
needs_backprop | ) |
|
|
virtual |
◆ spec()
virtual STRING tesseract::Network::spec |
( |
| ) |
const |
|
inline virtual
◆ TestFlag()
bool tesseract::Network::TestFlag |
( |
NetworkFlags |
flag | ) |
const |
|
inline |
◆ type()
◆ Update()
virtual void tesseract::Network::Update |
( |
float |
learning_rate, |
|
|
float |
momentum, |
|
|
float |
adam_beta, |
|
|
int |
num_samples |
|
) |
| |
|
inline virtual
◆ XScaleFactor()
virtual int tesseract::Network::XScaleFactor |
( |
| ) |
const |
|
inline virtual
◆ backward_win_
◆ forward_win_
◆ name_
STRING tesseract::Network::name_ |
|
protected |
◆ needs_to_backprop_
bool tesseract::Network::needs_to_backprop_ |
|
protected |
◆ network_flags_
int32_t tesseract::Network::network_flags_ |
|
protected |
◆ ni_
int32_t tesseract::Network::ni_ |
|
protected |
◆ no_
int32_t tesseract::Network::no_ |
|
protected |
◆ num_weights_
int32_t tesseract::Network::num_weights_ |
|
protected |
◆ randomizer_
TRand* tesseract::Network::randomizer_ |
|
protected |
◆ training_
◆ type_
The documentation for this class was generated from the following files: