19 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_ 20 #define TESSERACT_LSTM_LSTMTRAINER_H_ 96 const char* model_base,
const char* checkpoint_name,
97 int debug_interval, int64_t max_memory);
174 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
175 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
176 double cert_offset_step,
double max_cert_offset,
STRING* results);
186 bool randomly_rotate);
263 if (image !=
nullptr) {
293 if (data.
empty())
return false;
488 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
const UNICHARSET & GetUnicharset() const
void InitCharSet(const std::string &traineddata_path)
DocumentCache training_data_
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
bool TransitionTrainingStage(float error_threshold)
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
const DocumentCache & training_data() const
const double * error_rates() const
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
double worst_error_rates_[ET_COUNT]
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
bool SaveTraineddata(const STRING &filename)
int learning_iteration() const
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
DocumentCache * mutable_training_data()
static const int kRollingBufferSize_
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
LSTMTrainer * sub_trainer_
int InitTensorFlowNetwork(const std::string &tf_proto)
int training_iteration() const
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
const ImageData * GetPageBySerial(int serial)
bool SimpleTextOutput() const
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
double ActivationError() const
double best_error_rates_[ET_COUNT]
int prev_sample_iteration_
int CurrentTrainingStage() const
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
void UpdateErrorBuffer(double new_error, ErrorTypes type)
int sample_iteration() const
void PrepareLogMsg(STRING *log_msg) const
double LastSingleError(ErrorTypes type) const
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
GenericVector< char > best_trainer_
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
double ComputeWinnerError(const NetworkIO &deltas)
GenericVector< double > best_error_history_
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
void SaveRecognitionDump(GenericVector< char > *data) const
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
void FillErrorBuffer(double new_error, ErrorTypes type)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
void set_perfect_delay(int delay)
int32_t improvement_steps() const
bool Init(const char *data_file_name)
int best_iteration() const
GenericVector< int > best_error_iterations_
void LogIterations(const char *intro_str, STRING *log_msg) const
int32_t improvement_steps_
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
TessResultCallback4< STRING, int, const double *, const TessdataManager &, int > * TestCallback
void StartSubtrainer(STRING *log_msg)
double NewSingleError(ErrorTypes type) const
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
GenericVector< char > best_model_data_
STRING DumpFilename() const
int checkpoint_iteration_
double best_error_rate() const
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
float error_rate_of_last_saved_best_
double learning_rate() const
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
int32_t sample_iteration_
void InitCharSet(const TessdataManager &mgr)
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
int last_perfect_training_iteration_
GenericVector< char > worst_model_data_
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
GenericVector< double > error_buffers_[ET_COUNT]
void SetupCheckpointInfo()
CheckPointReader checkpoint_reader_
double ComputeRMSError(const NetworkIO &deltas)
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
const GenericVector< char > & best_trainer() const
CheckPointWriter checkpoint_writer_
double error_rates_[ET_COUNT]