19 #ifndef TESSERACT_LSTM_NETWORK_H_ 20 #define TESSERACT_LSTM_NETWORK_H_ 186 virtual int RemapOutputs(
int old_no,
const std::vector<int>& code_map) {
219 tprintf(
"Must override Network::DebugWeights for type %d\n",
type_);
231 virtual void Update(
float learning_rate,
float momentum,
float adam_beta,
237 double* changed)
const {}
265 tprintf(
"Must override Network::Forward for type %d\n",
type_);
276 tprintf(
"Must override Network::Backward for type %d\n",
type_);
287 static void ClearWindow(
bool tess_coords,
const char* window_name,
296 double Random(
double range);
320 #endif // TESSERACT_LSTM_NETWORK_H_
virtual void CacheXScaleFactor(int factor)
double Random(double range)
static Network * CreateFromFile(TFile *fp)
virtual void CountAlternators(const Network &other, double *same, double *changed) const
bool needs_to_backprop() const
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
ScrollView * forward_win_
virtual STRING spec() const
void DisplayBackward(const NetworkIO &matrix)
virtual bool DeSerialize(TFile *fp)
virtual void SetRandomizer(TRand *randomizer)
virtual int InitWeights(float range, TRand *randomizer)
virtual int XScaleFactor() const
virtual void SetEnableTraining(TrainingState state)
virtual bool SetupNeedsBackprop(bool needs_backprop)
virtual StaticShape InputShape() const
virtual bool Serialize(TFile *fp) const
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
void DisplayForward(const NetworkIO &matrix)
const STRING & name() const
static char const *const kTypeNames[NT_COUNT]
DLLSYM void tprintf(const char *format,...)
virtual void ConvertToInt()
virtual ~Network()=default
virtual StaticShape OutputShape(const StaticShape &input_shape) const
static int DisplayImage(Pix *pix, ScrollView *window)
ScrollView * backward_win_
void set_depth(int value)
virtual void SetNetworkFlags(uint32_t flags)
virtual bool IsPlumbingType() const
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
bool TestFlag(NetworkFlags flag) const
virtual void DebugWeights()
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)