tesseract
5.0.0-alpha-619-ge9db
|
Go to the documentation of this file.
18 #define _USE_MATH_DEFINES // needed to get definition of M_SQRT1_2
22 #include "config_auto.h"
28 #include "allheaders.h"
36 #ifdef INCLUDE_TENSORFLOW
75 : randomly_rotate_(false),
77 sub_trainer_(nullptr) {
83 int debug_interval, int64_t max_memory)
84 : randomly_rotate_(false),
85 training_data_(max_memory),
86 sub_trainer_(nullptr) {
104 const char* old_traineddata) {
107 tprintf(
"Loaded file %s, unpacking...\n", filename);
110 if (((old_traineddata ==
nullptr || *old_traineddata ==
'\0') &&
112 filename == old_traineddata) {
117 if (old_traineddata ==
nullptr || *old_traineddata ==
'\0') {
118 tprintf(
"Must supply the old traineddata for code conversion!\n");
130 std::vector<int> code_map =
MapRecoder(old_chset, old_recoder);
147 int net_flags,
float weight_range,
148 float learning_rate,
float momentum,
156 append_index, net_flags, weight_range,
161 tprintf(
"Built network:%s from request %s\n",
164 "Training parameters:\n Debug interval = %d,"
165 " weights = %g, learning rate = %g, momentum=%g\n",
173 #ifdef INCLUDE_TENSORFLOW
176 TFNetwork* tf_net =
new TFNetwork(
"TensorFlow");
179 tprintf(
"InitFromProtoStr failed!!\n");
203 for (
int i = 0; i <
ET_COUNT; ++i) {
216 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
217 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
218 double cert_offset_step,
double max_cert_offset,
STRING* results) {
240 for (
double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
241 for (
double c = min_cert_offset; c < max_cert_offset;
242 c += cert_offset_step) {
244 search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
249 if ((r == min_dict_ratio && c == min_cert_offset) ||
250 !std::isfinite(word_error)) {
253 tprintf(
"r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
254 t.
c_str(), o.
c_str(), word_error, truth_labels[0]);
274 bool randomly_rotate) {
311 *log_msg +=
UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
324 *log_msg +=
" failed to write best model:";
326 *log_msg +=
" wrote best model:";
329 *log_msg += best_model_name;
334 *log_msg +=
UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
338 *log_msg +=
"\nDivergence! ";
355 result = sub_trainer_result !=
STR_NONE;
362 *log_msg +=
" failed to write checkpoint.";
364 *log_msg +=
" wrote checkpoint.";
385 *log_msg += intro_str;
411 if (!error_buffer.Serialize(fp))
return false;
415 uint8_t amount = serialize_amount;
416 if (!fp->
Serialize(&amount))
return false;
417 if (serialize_amount ==
LIGHT)
return true;
432 if (!sub_data.
Serialize(fp))
return false;
446 tprintf(
"Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
455 if (!error_buffer.DeSerialize(fp))
return false;
461 if (amount ==
LIGHT)
return true;
475 if (sub_data.
empty()) {
493 *log_msg +=
" Failed to revert to previous best for trial!";
497 log_msg->
add_str_int(
" Trial sub_trainer_ from iteration ",
522 double sub_margin = (training_error - sub_error) / sub_error;
531 int target_iteration =
536 STRING batch_log =
"Sub:";
540 *log_msg += batch_log;
542 sub_margin = (training_error - sub_error) / sub_error;
550 log_msg->
add_str_int(
" Sub trainer wins at iteration ",
567 log_msg->
add_str_int(
"\nReduced learning rate on layers: ", num_reduced);
589 int num_layers = layers.
size();
594 for (
int i = 0; i < LR_COUNT; ++i) {
598 double momentum_factor = 1.0 / (1.0 -
momentum_);
601 for (
int i = 0; i < num_layers; ++i) {
606 for (
int s = 0; s < num_samples; ++s) {
608 for (
int ww = 0; ww < LR_COUNT; ++ww) {
610 float ww_factor = momentum_factor;
611 if (ww == LR_DOWN) ww_factor *= factor;
618 for (
int i = 0; i < num_layers; ++i) {
619 if (num_weights[i] == 0)
continue;
627 if (trainingdata ==
nullptr)
continue;
631 for (
int i = 0; i < num_layers; ++i) {
632 if (num_weights[i] == 0)
continue;
640 layer->
Update(0.0, 0.0, 0.0, 0);
644 float before_bad = bad_sums[ww][i];
645 float before_ok = ok_sums[ww][i];
647 &ok_sums[ww][i], &bad_sums[ww][i]);
649 bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
651 bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
657 for (
int i = 0; i < num_layers; ++i) {
658 if (num_weights[i] == 0)
continue;
661 double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
662 double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
663 double frac_down = bad_sums[LR_DOWN][i] / total_down;
664 double frac_same = bad_sums[LR_SAME][i] / total_same;
666 lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
675 if (num_lowered == 0) {
677 for (
int i = 0; i < num_layers; ++i) {
678 if (num_weights[i] > 0) {
694 tprintf(
"Empty truth string!\n");
702 if (unicharset.
encode_string(cleaned.c_str(),
true, &internal_labels,
nullptr,
705 for (
int i = 0; i < internal_labels.
size(); ++i) {
706 if (recoder !=
nullptr) {
711 for (
int j = 0; j < len; ++j) {
725 if (success)
return true;
727 tprintf(
"Encoding of string failed! Failure bytes:");
728 while (err_index < cleaned.size()) {
729 tprintf(
" %x", cleaned[err_index++] & 0xff);
758 #ifndef GRAPHICS_DISABLED
762 #endif // GRAPHICS_DISABLED
773 if (trainingdata ==
nullptr) {
774 tprintf(
"Null trainingdata.\n");
782 tprintf(
"Can't encode transcription: '%s' in language '%s'\n",
787 bool upside_down =
false;
797 for (
int c = 0; c < truth_labels.
size(); ++c) {
805 while (w < truth_labels.
size() &&
808 if (w == truth_labels.
size()) {
809 tprintf(
"Blank transcription: %s\n",
816 if (!
RecognizeLine(*trainingdata, invert, debug, invert, upside_down,
817 &image_scale, &inputs, fwd_outputs)) {
818 tprintf(
"Image not trainable\n");
825 tprintf(
"Compute simple targets failed!\n");
828 }
else if (loss_type ==
LT_CTC) {
830 tprintf(
"Compute CTC targets failed!\n");
834 tprintf(
"Logistic outputs not implemented yet!\n");
841 if (loss_type !=
LT_CTC) {
853 if (truth_text != ocr_text) {
854 tprintf(
"Iteration %d: BEST OCR TEXT : %s\n",
863 trainingdata->
page_number(), delta_error == 0.0 ?
"(Perfect)" :
"");
865 if (delta_error == 0.0)
return PERFECT;
884 const char* data,
int size) {
886 tprintf(
"Warning: data size is 0 in LSTMTrainer::ReadLocalTrainingDump\n");
899 recognizer_data.
size());
920 filename +=
".checkpoint";
937 std::vector<int> code_map(num_new_codes, -1);
938 for (
int c = 0; c < num_new_codes; ++c) {
942 for (
int uid = 0; uid <= num_new_unichars; ++uid) {
946 while (code_index < length && codes(code_index) != c) ++code_index;
947 if (code_index == length)
continue;
950 uid < num_new_unichars
952 : old_chset.
size() - 1;
953 if (old_uid == INVALID_UNICHAR_ID)
continue;
956 if (code_index < old_recoder.
EncodeUnichar(old_uid, &old_codes)) {
957 old_code = old_codes(code_index);
961 code_map[c] = old_code;
974 "Must provide a traineddata containing lstm_unicharset and"
975 " lstm_recoder!\n" !=
nullptr);
1011 if (truth_text.
c_str() ==
nullptr || truth_text.
length() <= 0) {
1012 tprintf(
"Empty truth string at decode time!\n");
1021 tprintf(
"Iteration %d: GROUND TRUTH : %s\n",
1023 if (truth_text != text) {
1024 tprintf(
"Iteration %d: ALIGNED TRUTH : %s\n",
1028 tprintf(
"TRAINING activation path for truth string %s\n",
1029 truth_text.
c_str());
1043 const char* window_name,
ScrollView** window) {
1044 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1045 int width = targets.
Width();
1049 for (
int c = 0; c < num_features; ++c) {
1051 (*window)->Pen(static_cast<ScrollView::Color>(color));
1053 for (
int t = 0; t < width; ++t) {
1054 double target = targets.
f(t)[c];
1058 (*window)->SetCursor(t - 1, 0);
1061 (*window)->DrawTo(t, target);
1062 }
else if (start_t >= 0) {
1063 (*window)->DrawTo(t, 0);
1064 (*window)->DrawTo(start_t - 1, 0);
1069 (*window)->DrawTo(width, 0);
1070 (*window)->DrawTo(start_t - 1, 0);
1073 (*window)->Update();
1074 #endif // GRAPHICS_DISABLED
1082 if (truth_labels.
size() > targets->
Width()) {
1083 tprintf(
"Error: transcription %s too long to fit into target of width %d\n",
1087 for (
int i = 0; i < truth_labels.
size() && i < targets->
Width(); ++i) {
1090 for (
int i = truth_labels.
size(); i < targets->
Width(); ++i) {
1111 double char_error,
double word_error) {
1131 double total_error = 0.0;
1132 int width = deltas.
Width();
1134 for (
int t = 0; t < width; ++t) {
1135 const float* class_errs = deltas.
f(t);
1136 for (
int c = 0; c < num_classes; ++c) {
1137 double error = class_errs[c];
1138 total_error += error * error;
1141 return sqrt(total_error / (width * num_classes));
1151 int width = deltas.
Width();
1153 for (
int t = 0; t < width; ++t) {
1154 const float* class_errs = deltas.
f(t);
1155 for (
int c = 0; c < num_classes; ++c) {
1156 float abs_delta = fabs(class_errs[c]);
1159 if (0.5 <= abs_delta)
1163 return static_cast<double>(num_errors) / width;
1172 for (
int i = 0; i < truth_str.
size(); ++i) {
1174 ++label_counts[truth_str[i]];
1178 for (
int i = 0; i < ocr_str.
size(); ++i) {
1180 --label_counts[ocr_str[i]];
1183 int char_errors = 0;
1184 for (
int i = 0; i < label_counts.
size(); ++i) {
1185 char_errors += abs(label_counts[i]);
1187 if (truth_size == 0) {
1188 return (char_errors == 0) ? 0.0 : 1.0;
1190 return static_cast<double>(char_errors) / truth_size;
1196 using StrMap = std::unordered_map<std::string, int, std::hash<std::string>>;
1198 truth_str->
split(
' ', &truth_words);
1199 if (truth_words.
empty())
return 0.0;
1200 ocr_str->
split(
' ', &ocr_words);
1202 for (
int i = 0; i < truth_words.
size(); ++i) {
1204 auto it = word_counts.find(truth_word);
1205 if (it == word_counts.end())
1206 word_counts.insert(std::make_pair(truth_word, 1));
1210 for (
int i = 0; i < ocr_words.
size(); ++i) {
1212 auto it = word_counts.find(ocr_word);
1213 if (it == word_counts.end())
1214 word_counts.insert(std::make_pair(ocr_word, -1));
1218 int word_recall_errs = 0;
1219 for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1221 if (it->second > 0) word_recall_errs += it->second;
1223 return static_cast<double>(word_recall_errs) / truth_words.
size();
1233 double buffer_sum = 0.0;
1235 double mean = buffer_sum / mean_count;
1249 tprintf(
"Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1297 double two_percent_more = error_rate + 2.0;
1304 tprintf(
"2 Percent improvement time=%d, best error was %g @ %d\n",
1309 if (tester !=
nullptr) {
bool load_from_file(const char *const filename, bool skip_fragments)
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum, float adam_beta)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
static constexpr float kMinCertainty
const double kSubTrainerMarginFraction
bool LoadAllTrainingData(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, bool randomly_rotate)
const double kMinDivergenceRate
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)=0
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
void add_str_int(const char *str, int number)
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
void SetVersionString(const std::string &v_str)
virtual void SetEnableTraining(TrainingState state)
void SaveRecognitionDump(GenericVector< char > *data) const
void ScaleLayerLearningRate(const STRING &id, double factor)
bool DeSerialize(TFile *fp)
double best_error_rates_[ET_COUNT]
std::vector< int > MapRecoder(const UNICHARSET &old_chset, const UnicharCompress &old_recoder) const
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
int prev_sample_iteration_
int training_iteration() const
std::string VersionString() const
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
constexpr size_t countof(T const (&)[N]) noexcept
STRING DumpFilename() const
double learning_rate() const
const int kNumAdjustmentIterations
virtual int RemapOutputs(int old_no, const std::vector< int > &code_map)
GenericVector< double > error_buffers_[ET_COUNT]
bool Serialize(FILE *fp) const
const STRING & imagefilename() const
bool SaveDataToFile(const GenericVector< char > &data, const char *filename)
int32_t sample_iteration_
const GENERIC_2D_ARRAY< float > & float_array() const
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
int IntCastRounded(double x)
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset, int lstm_choice_mode=0)
GenericVector< STRING > EnumerateLayers() const
void LabelsFromOutputs(const NetworkIO &outputs, GenericVector< int > *labels, GenericVector< int > *xcoords)
const int kMinStartedErrorRate
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
const UNICHARSET & GetUnicharset() const
virtual void DebugWeights()=0
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
GenericVector< double > best_error_history_
bool Open(const STRING &filename, FileReader reader)
bool TestFlag(NetworkFlags flag) const
float GetLayerLearningRate(const STRING &id) const
GenericVector< char > best_trainer_
void FillErrorBuffer(double new_error, ErrorTypes type)
void StartSubtrainer(STRING *log_msg)
void SetActivations(int t, int label, float ok_score)
void OverwriteEntry(TessdataType type, const char *data, int size)
bool LoadCharsets(const TessdataManager *mgr)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
const GenericVector< TBOX > & boxes() const
void SetIteration(int iteration)
float error_rate_of_last_saved_best_
bool Serialize(SerializeAmount serialize_amount, const TessdataManager *mgr, TFile *fp) const
double NewSingleError(ErrorTypes type) const
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, const TBOX &line_box, PointerVector< WERD_RES > *words, int lstm_choice_mode=0, int lstm_choice_amount=5)
NetworkScratch scratch_space_
bool ReadLocalTrainingDump(const TessdataManager *mgr, const char *data, int size)
const STRING & transcription() const
const char * c_str() const
bool DeSerialize(char *data, size_t count=1)
bool DeSerialize(bool swap, FILE *fp)
bool Serialize(const char *data, size_t count=1)
GenericVector< char > best_model_data_
void UpdateErrorBuffer(double new_error, ErrorTypes type)
int learning_iteration() const
double ComputeRMSError(const NetworkIO &deltas)
bool GetComponent(TessdataType type, TFile *fp)
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
LossType OutputLossType() const
void ScaleLearningRate(double factor)
virtual StaticShape InputShape() const
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
const int kNumPagesPerBatch
int32_t improvement_steps_
Network * GetLayer(const STRING &id) const
const double kStageTransitionThreshold
bool Serialize(const TessdataManager *mgr, TFile *fp) const
static std::string CleanupString(const char *utf8_str)
int32_t training_iteration_
bool TryLoadingCheckpoint(const char *filename, const char *old_traineddata)
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer) const
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
const double kLearningRateDecay
bool SimpleTextOutput() const
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
bool has_special_codes() const
double SignedRand(double range)
int InitTensorFlowNetwork(const std::string &tf_proto)
LSTMTrainer * sub_trainer_
bool DeSerialize(const TessdataManager *mgr, TFile *fp)
const double kImprovementFraction
void SubtractAllFromFloat(const NetworkIO &src)
SVEvent * AwaitEvent(SVEventType type)
const double kBestCheckpointFraction
void Resize(const NetworkIO &src, int num_features)
const STRING & name() const
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
virtual void Update(float learning_rate, float momentum, float adam_beta, int num_samples)
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
bool Init(const char *data_file_name)
static void NormalizeProbs(NetworkIO *probs)
void add_str_double(const char *str, double number)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
const int kMinStallIterations
void init_to_size(int size, const T &t)
const STRING & language() const
int last_perfect_training_iteration_
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
bool AnySuspiciousTruth(float confidence_thr) const
int CurrentTrainingStage() const
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
DLLSYM void tprintf(const char *format,...)
GenericVector< int > best_error_iterations_
const int kErrorGraphInterval
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
virtual STRING spec() const
GenericVector< char > worst_model_data_
double worst_error_rates_[ET_COUNT]
const double kHighConfidence
static const int kRollingBufferSize_
double error_rates_[ET_COUNT]
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
DocumentCache training_data_
bool TransitionTrainingStage(float error_threshold)
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
void PrepareLogMsg(STRING *log_msg) const
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
void OpenWrite(GenericVector< char > *data)
bool SaveTraineddata(const STRING &filename)
int checkpoint_iteration_
virtual void CountAlternators(const Network &other, double *same, double *changed) const
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
std::function< STRING(int, const double *, const TessdataManager &, int)> TestCallback
LIST search(LIST list, void *key, int_compare is_equal)
bool SaveFile(const STRING &filename, FileWriter writer) const
void split(char c, GenericVector< STRING > *splited)
double ComputeWinnerError(const NetworkIO &deltas)
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
void LogIterations(const char *intro_str, STRING *log_msg) const
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
int sample_iteration() const
STRING DecodeLabels(const GenericVector< int > &labels)