tesseract
5.0.0-alpha-619-ge9db
ctc.h
Go to the documentation of this file.
1
// File: ctc.h
3
// Description: Slightly improved standard CTC to compute the targets.
4
// Author: Ray Smith
5
// Created: Wed Jul 13 15:17:06 PDT 2016
6
//
7
// (C) Copyright 2016, Google Inc.
8
// Licensed under the Apache License, Version 2.0 (the "License");
9
// you may not use this file except in compliance with the License.
10
// You may obtain a copy of the License at
11
// http://www.apache.org/licenses/LICENSE-2.0
12
// Unless required by applicable law or agreed to in writing, software
13
// distributed under the License is distributed on an "AS IS" BASIS,
14
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
// See the License for the specific language governing permissions and
16
// limitations under the License.
18
19
#ifndef TESSERACT_LSTM_CTC_H_
20
#define TESSERACT_LSTM_CTC_H_
21
22
#include <cmath>

#include <tesseract/genericvector.h>

#include "network.h"
#include "networkio.h"
#include "scrollview.h"
26
27
namespace
tesseract
{
28
29
// Class to encapsulate CTC and simple target generation.
30
class
CTC
{
31
public
:
32
// Normalizes the probabilities such that no target has a prob below min_prob,
33
// and, provided that the initial total is at least min_total_prob, then all
34
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
35
// probability is thus 1 - (num_classes-1)*min_prob.
36
static
void
NormalizeProbs
(
NetworkIO
* probs) {
37
NormalizeProbs
(probs->
mutable_float_array
());
38
}
39
40
// Builds a target using CTC. Slightly improved as follows:
41
// Includes normalizations and clipping for stability.
42
// labels should be pre-padded with nulls wherever desired, but they don't
43
// have to be between all labels. Allows for multi-label codes with no
44
// nulls between.
45
// labels can be longer than the time sequence, but the total number of
46
// essential labels (non-null plus nulls between equal labels) must not exceed
47
// the number of timesteps in outputs.
48
// outputs is the output of the network, and should have already been
49
// normalized with NormalizeProbs.
50
// On return targets is filled with the computed targets.
51
// Returns false if there is insufficient time for the labels.
52
static
bool
ComputeCTCTargets
(
const
GenericVector<int>
& truth_labels,
53
int
null_char,
54
const
GENERIC_2D_ARRAY<float>
& outputs,
55
NetworkIO
* targets);
56
57
private
:
58
// Constructor is private as the instance only holds information specific to
59
// the current labels, outputs etc, and is built by the static function.
60
CTC
(
const
GenericVector<int>
& labels,
int
null_char,
61
const
GENERIC_2D_ARRAY<float>
& outputs);
62
63
// Computes vectors of min and max label index for each timestep, based on
64
// whether skippability of nulls makes it possible to complete a valid path.
65
bool
ComputeLabelLimits();
66
// Computes targets based purely on the labels by spreading the labels evenly
67
// over the available timesteps.
68
void
ComputeSimpleTargets(
GENERIC_2D_ARRAY<float>
* targets)
const
;
69
// Computes mean positions and half widths of the simple targets by spreading
70
// the labels even over the available timesteps.
71
void
ComputeWidthsAndMeans(
GenericVector<float>
* half_widths,
72
GenericVector<int>
* means)
const
;
73
// Calculates and returns a suitable fraction of the simple targets to add
74
// to the network outputs.
75
float
CalculateBiasFraction();
76
// Runs the forward CTC pass, filling in log_probs.
77
void
Forward(
GENERIC_2D_ARRAY<double>
* log_probs)
const
;
78
// Runs the backward CTC pass, filling in log_probs.
79
void
Backward(
GENERIC_2D_ARRAY<double>
* log_probs)
const
;
80
// Normalizes and brings probs out of log space with a softmax over time.
81
void
NormalizeSequence(
GENERIC_2D_ARRAY<double>
* probs)
const
;
82
// For each timestep computes the max prob for each class over all
83
// instances of the class in the labels_, and sets the targets to
84
// the max observed prob.
85
void
LabelsToClasses(
const
GENERIC_2D_ARRAY<double>
& probs,
86
NetworkIO
* targets)
const
;
87
// Normalizes the probabilities such that no target has a prob below min_prob,
88
// and, provided that the initial total is at least min_total_prob, then all
89
// probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
90
// probability is thus 1 - (num_classes-1)*min_prob.
91
static
void
NormalizeProbs
(
GENERIC_2D_ARRAY<float>
* probs);
92
// Returns true if the label at index is a needed null.
93
bool
NeededNull(
int
index)
const
;
94
// Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
95
// underflow.
96
static
double
ClippedExp(
double
x) {
97
if
(x < -kMaxExpArg_)
return
exp(-kMaxExpArg_);
98
if
(x > kMaxExpArg_)
return
exp(kMaxExpArg_);
99
return
exp(x);
100
}
101
102
// Minimum probability limit for softmax input to ctc_loss.
103
static
const
float
kMinProb_;
104
// Maximum absolute argument to exp().
105
static
const
double
kMaxExpArg_;
106
// Minimum probability for total prob in time normalization.
107
static
const
double
kMinTotalTimeProb_;
108
// Minimum probability for total prob in final normalization.
109
static
const
double
kMinTotalFinalProb_;
110
111
// The truth label indices that are to be matched to outputs_.
112
const
GenericVector<int>
& labels_;
113
// The network outputs.
114
GENERIC_2D_ARRAY<float>
outputs_;
115
// The null or "blank" label.
116
int
null_char_;
117
// Number of timesteps in outputs_.
118
int
num_timesteps_;
119
// Number of classes in outputs_.
120
int
num_classes_;
121
// Number of labels in labels_.
122
int
num_labels_;
123
// Min and max valid label indices for each timestep.
124
GenericVector<int>
min_labels_;
125
GenericVector<int>
max_labels_;
126
};
127
128
}
// namespace tesseract
129
130
#endif // TESSERACT_LSTM_CTC_H_
networkio.h
network.h
GENERIC_2D_ARRAY< float >
genericvector.h
tesseract::NetworkIO
Definition:
networkio.h:39
tesseract::CTC
Definition:
ctc.h:30
tesseract
Definition:
baseapi.h:65
tesseract::NetworkIO::mutable_float_array
GENERIC_2D_ARRAY< float > * mutable_float_array()
Definition:
networkio.h:140
GenericVector< int >
tesseract::CTC::NormalizeProbs
static void NormalizeProbs(NetworkIO *probs)
Definition:
ctc.h:36
scrollview.h
tesseract::CTC::ComputeCTCTargets
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition:
ctc.cpp:54
src
training
ctc.h
Generated on Thu Jan 30 2020 14:22:21 for tesseract by
1.8.16