49 for (
int net_idx = 0; net_idx < nets_.size(); net_idx++) {
50 if (nets_[net_idx] !=
NULL) {
51 delete nets_[net_idx];
56 if (net_input_ !=
NULL) {
61 if (net_output_ !=
NULL) {
// Fold(): rewrites net_output_ in place so that related classes share their
// activation. This is an excerpt: several statements of the body (and all
// closing braces) are not visible here, so the notes below only describe the
// lines that are shown.
83 void HybridNeuralNetCharClassifier::Fold() {
// Pass 1 (case folding): visit every class id below class_cnt. class_cnt is
// assigned on a line that is not visible in this excerpt.
88 for (
int class_id = 0; class_id < class_cnt; class_id++) {
// Upper-case, in place, every alphabetic element of upper_form32.
// NOTE(review): ch is a signed int compared against the unsigned length().
93 for (
int ch = 0; ch < upper_form32.length(); ch++) {
94 if (iswalpha(static_cast<int>(upper_form32[ch])) != 0) {
95 upper_form32[ch] = towupper(upper_form32[ch]);
// Tail of a call that receives the upper-cased string; presumably the class-id
// lookup that produces upper_class_id -- TODO confirm against the full source.
102 upper_form32.c_str()));
// When an upper-case class id was obtained (not -1) and it differs from the
// current class, both entries of net_output_ are raised to the larger of the two.
103 if (upper_class_id != -1 && class_id != upper_class_id) {
104 float max_out =
MAX(net_output_[class_id], net_output_[upper_class_id]);
105 net_output_[class_id] = max_out;
106 net_output_[upper_class_id] = max_out;
// Pass 2 (folding sets): for each of the fold_set_cnt_ sets, find the largest
// output among its members, then lift every member to at least
// max_prob * kFoldingRatio.
115 for (
int fold_set = 0; fold_set <
fold_set_cnt_; fold_set++) {
// Seed the maximum with the first member of the set.
// NOTE(review): no check of fold_set_len_[fold_set] or of a NULL
// fold_sets_[fold_set] is visible before this read, while LoadFoldingSets
// stores length 0 / NULL for invalidated sets -- confirm this cannot
// dereference NULL.
116 float max_prob = net_output_[
fold_sets_[fold_set][0]];
// Running maximum over the members of the set (the enclosing loop header is
// not visible in this excerpt).
119 if (net_output_[fold_sets_[fold_set][ch]] > max_prob) {
120 max_prob = net_output_[fold_sets_[fold_set][ch]];
// Each member keeps its own output unless the scaled set maximum is larger.
123 for (
int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
124 net_output_[fold_sets_[fold_set][ch]] =
MAX(max_prob * kFoldingRatio,
125 net_output_[fold_sets_[fold_set][ch]]);
// RunNets(): feeds the input buffer through every net in nets_ and accumulates
// the weighted outputs into net_output_. Returns bool; the return statements
// themselves are not visible in this excerpt.
//   char_samp: the sample to classify; how it is turned into net_input_ is on
//              lines that are not shown here.
132 bool HybridNeuralNetCharClassifier::RunNets(CharSamp *char_samp) {
// Buffers are allocated only while net_input_ is still NULL: feat_cnt floats
// of input and class_cnt floats of output. Whether the net_output_ allocation
// sits inside the same guard cannot be seen from this excerpt.
// NOTE(review): a plain new[] reports failure by throwing std::bad_alloc, so
// the NULL tests that follow each allocation are unreachable unless the
// project replaces operator new -- TODO confirm.
137 if (net_input_ ==
NULL) {
138 net_input_ =
new float[feat_cnt];
139 if (net_input_ ==
NULL) {
143 net_output_ =
new float[class_cnt];
144 if (net_output_ ==
NULL) {
// Clear the accumulator before summing the per-net contributions.
155 memset(net_output_, 0, class_cnt *
sizeof(*net_output_));
// 'inputs' walks through net_input_; each net reads its own slice of it.
156 float *inputs = net_input_;
// NOTE(review): net_idx is a signed int compared against the unsigned size().
157 for (
int net_idx = 0; net_idx < nets_.size(); net_idx++) {
// Scratch output for this net, zero-initialised, one entry per class.
159 vector<float> net_out(class_cnt, 0.0);
// A failed forward pass is handled on lines not visible in this excerpt.
160 if (!nets_[net_idx]->FeedForward(inputs, &net_out[0])) {
// Add this net's outputs, scaled by its weight, into the shared accumulator.
164 for (
int class_idx = 0; class_idx < class_cnt; class_idx++) {
165 net_output_[class_idx] += (net_out[class_idx] * net_wgts_[net_idx]);
// Step past the inputs consumed by this net so the next net reads its slice.
168 inputs += nets_[net_idx]->in_cnt();
180 if (RunNets(char_samp) ==
false) {
191 if (RunNets(char_samp) ==
false) {
199 if (alt_list ==
NULL) {
203 for (
int out = 1; out < class_cnt; out++) {
205 alt_list->
Insert(out, cost);
// LoadFoldingSets(): fills fold_set_cnt_, fold_sets_ and fold_set_len_ from the
// file <data_file_path><lang>.cube.fold. Excerpt only: file reading, string
// splitting and all return statements are on lines not visible here.
//   data_file_path: prefix that the language name is appended to.
//   lang:           language name used to build the file name.
//   lang_mod:       language model used to strip invalid characters from each
//                   set; it is used as a TessLangModel (see the cast below).
218 bool HybridNeuralNetCharClassifier::LoadFoldingSets(
219 const string &data_file_path,
const string &
lang,
LangModel *lang_mod) {
// File name is data_file_path + lang + ".cube.fold".
221 string fold_file_name;
222 fold_file_name = data_file_path +
lang;
223 fold_file_name +=
".cube.fold";
// Opened in binary read mode; handling of a failed open and the matching
// fclose are not visible in this excerpt -- TODO confirm fp is always closed.
226 FILE *fp = fopen(fold_file_name.c_str(),
"rb");
232 string fold_sets_str;
// One entry of str_vec per folding set (it is filled on a line not shown here).
239 vector<string> str_vec;
241 fold_set_cnt_ = str_vec.size();
// NULL checks on the two per-set arrays; the allocations themselves are on
// lines not visible in this excerpt.
243 if (fold_sets_ ==
NULL) {
247 if (fold_set_len_ ==
NULL) {
252 for (
int fold_set = 0; fold_set <
fold_set_cnt_; fold_set++) {
// NOTE(review): reinterpret_cast between class pointer types; if
// TessLangModel derives from LangModel a static_cast is the appropriate cast
// -- TODO confirm the class hierarchy.
253 reinterpret_cast<TessLangModel *
>(lang_mod)->RemoveInvalidCharacters(
// A set with at most one character is invalidated: a warning is printed and
// the set is stored with length 0 and a NULL member array.
// NOTE(review): the message names ConvNetCharClassifier although this method
// belongs to HybridNeuralNetCharClassifier (looks like a copy/paste leftover).
257 if (str_vec[fold_set].length() <= 1) {
258 fprintf(stderr,
"Cube WARNING (ConvNetCharClassifier::LoadFoldingSets): "
259 "invalidating folding set %d\n", fold_set);
260 fold_set_len_[fold_set] = 0;
261 fold_sets_[fold_set] =
NULL;
// Otherwise the set gets one int per element of str32 (str32 is produced on
// a line not visible here).
267 fold_set_len_[fold_set] = str32.length();
268 fold_sets_[fold_set] =
new int[fold_set_len_[fold_set]];
// On allocation failure the set count is cut back to the sets built so far.
// NOTE(review): a plain new[] throws std::bad_alloc instead of returning NULL,
// so this branch is unreachable unless operator new is replaced; the message
// also carries the ConvNetCharClassifier name.
269 if (fold_sets_[fold_set] ==
NULL) {
270 fprintf(stderr,
"Cube ERROR (ConvNetCharClassifier::LoadFoldingSets): "
271 "could not allocate folding set\n");
272 fold_set_cnt_ = fold_set;
// Per-character loop over the new set; its body is not visible in this excerpt.
275 for (
int ch = 0; ch < fold_set_len_[fold_set]; ch++) {
// Init(): calls LoadNets and then LoadFoldingSets with the given path and
// language. What happens when either call returns false, and the final return
// value, are on lines not visible in this excerpt. The parameter between
// data_file_path and lang_mod is also on a line that is not shown.
283 bool HybridNeuralNetCharClassifier::Init(
const string &data_file_path,
285 LangModel *lang_mod) {
// Step 1: load the nets for this language.
292 if (!LoadNets(data_file_path, lang)) {
// Step 2: load the folding sets, passing the language model through.
298 if (!LoadFoldingSets(data_file_path, lang, lang_mod)) {
// LoadNets(): fills nets_ and net_wgts_ from the entries described by the file
// <data_file_path><lang>.cube.hybrid. Excerpt only: file reading, string
// splitting, error handling and return statements are on lines not visible here.
//   data_file_path: prefix for the hybrid file and for each net file name.
//   lang:           language name used to build the hybrid file name.
309 bool HybridNeuralNetCharClassifier::LoadNets(
const string &data_file_path,
310 const string &lang) {
311 string hybrid_net_file;
// NOTE(review): junk_net_file has no visible use in this excerpt -- confirm
// whether it is still needed.
312 string junk_net_file;
// File name is data_file_path + lang + ".cube.hybrid".
315 hybrid_net_file = data_file_path +
lang;
316 hybrid_net_file +=
".cube.hybrid";
// Opened in binary read mode; handling of a failed open and the matching
// fclose are not visible in this excerpt.
319 FILE *fp = fopen(hybrid_net_file.c_str(),
"rb");
// One entry of str_vec per net (it is filled on a line not shown here).
331 vector<string> str_vec;
// size() is unsigned, so "<= 0" can only be true for an empty vector.
333 if (str_vec.size() <= 0) {
// One slot per entry: net pointers start as NULL, weights start at 0.
338 nets_.resize(str_vec.size(),
NULL);
339 net_wgts_.resize(str_vec.size(), 0);
340 int total_input_size = 0;
// NOTE(review): net_idx is a signed int compared against the unsigned size().
341 for (
int net_idx = 0; net_idx < str_vec.size(); net_idx++) {
// Each entry must split into exactly two tokens: tokens_vec[0] is used as the
// net file name and tokens_vec[1] as its weight.
343 vector<string> tokens_vec;
346 if (tokens_vec.size() != 2) {
// The net file name is relative to data_file_path.
350 string net_file_name = data_file_path + tokens_vec[0];
// The net itself is loaded on a line not visible here; a NULL slot is tested
// next, and its handling is also not shown.
352 if (nets_[net_idx] ==
NULL) {
// NOTE(review): atof gives no error indication, so a malformed weight cannot
// be told apart from a literal 0 here.
356 net_wgts_[net_idx] = atof(tokens_vec[1].c_str());
// Negative weights are tested for; the handling is not visible in this excerpt.
357 if (net_wgts_[net_idx] < 0.0) {
// Running total of the input counts of all nets loaded so far.
360 total_input_size += nets_[net_idx]->in_cnt();
bool Insert(int class_id, int cost, void *tag=NULL)
static int Prob2Cost(double prob_val)
virtual bool ComputeFeatures(CharSamp *char_samp, float *features)=0
basic_string< char_32 > string_32
static bool ReadFileToString(const string &file_name, string *str)
virtual bool Train(CharSamp *char_samp, int ClassID)
virtual CharAltList * Classify(CharSamp *char_samp)
FeatureBase * feat_extract_
virtual ~HybridNeuralNetCharClassifier()
void SetNet(tesseract::NeuralNet *net)
static void UTF8ToUTF32(const char *utf8_str, string_32 *str32)
int ClassID(const char_32 *str) const
HybridNeuralNetCharClassifier(CharSet *char_set, TuningParams *params, FeatureBase *feat_extract)
virtual int CharCost(CharSamp *char_samp)
virtual bool SetLearnParam(char *var_name, float val)
static void SplitStringUsing(const string &str, const string &delims, vector< string > *str_vec)
virtual int FeatureCnt()=0
const char_32 * ClassString(int class_id) const
static NeuralNet * FromFile(const string file_name)