tesseract
5.0.0-alpha-619-ge9db
|
Go to the documentation of this file.
18 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_
19 #define TESSERACT_LSTM_LSTMTRAINER_H_
82 LSTMTrainer(
const char* model_base,
const char* checkpoint_name,
83 int debug_interval, int64_t max_memory);
160 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
161 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
162 double cert_offset_step,
double max_cert_offset,
STRING* results);
172 bool randomly_rotate);
249 if (image !=
nullptr) {
279 if (data.
empty())
return false;
467 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
void InitCharSet(const TessdataManager &mgr)
double ActivationError() const
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
void SaveRecognitionDump(GenericVector< char > *data) const
DocumentCache * mutable_training_data()
bool ReadSizedTrainingDump(const char *data, int size, LSTMTrainer *trainer) const
double best_error_rates_[ET_COUNT]
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
int prev_sample_iteration_
int best_iteration() const
int training_iteration() const
void SetupCheckpointInfo()
STRING DumpFilename() const
double learning_rate() const
GenericVector< double > error_buffers_[ET_COUNT]
int32_t sample_iteration_
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
double LastSingleError(ErrorTypes type) const
const UNICHARSET & GetUnicharset() const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
GenericVector< double > best_error_history_
GenericVector< char > best_trainer_
void FillErrorBuffer(double new_error, ErrorTypes type)
void StartSubtrainer(STRING *log_msg)
float error_rate_of_last_saved_best_
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
double NewSingleError(ErrorTypes type) const
const ImageData * GetPageBySerial(int serial)
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
GenericVector< char > best_model_data_
void UpdateErrorBuffer(double new_error, ErrorTypes type)
int learning_iteration() const
double ComputeRMSError(const NetworkIO &deltas)
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
int32_t improvement_steps_
const DocumentCache & training_data() const
int32_t improvement_steps() const
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
void set_perfect_delay(int delay)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
bool SimpleTextOutput() const
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
const double * error_rates() const
int InitTensorFlowNetwork(const std::string &tf_proto)
LSTMTrainer * sub_trainer_
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
void InitCharSet(const std::string &traineddata_path)
double best_error_rate() const
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
bool Init(const char *data_file_name)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
int last_perfect_training_iteration_
const GenericVector< char > & best_trainer() const
int CurrentTrainingStage() const
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
GenericVector< int > best_error_iterations_
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
GenericVector< char > worst_model_data_
double worst_error_rates_[ET_COUNT]
static const int kRollingBufferSize_
double error_rates_[ET_COUNT]
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
DocumentCache training_data_
bool TransitionTrainingStage(float error_threshold)
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
void PrepareLogMsg(STRING *log_msg) const
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
bool SaveTraineddata(const STRING &filename)
int checkpoint_iteration_
std::function< STRING(int, const double *, const TessdataManager &, int)> TestCallback
double ComputeWinnerError(const NetworkIO &deltas)
void LogIterations(const char *intro_str, STRING *log_msg) const
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
int sample_iteration() const