// Floor for normalized class probabilities (the per-timestep normalization
// clamps each one with std::max(prob, kMinProb_)), and the base of the bias
// factor returned by CalculateBiasFraction (via log(kMinProb_)).
36 const float CTC::kMinProb_ = 1e-12;
// Not referenced in the visible excerpt. NOTE(review): presumably the clamp
// applied to exp() arguments inside ClippedExp — confirm.
38 const double CTC::kMaxExpArg_ = 80.0;
// Lower bound on `total`, the divisor used when normalizing each label's
// probabilities across all timesteps (the loop over u that calls ClippedExp),
// so that division never uses a vanishing denominator.
40 const double CTC::kMinTotalTimeProb_ = 1e-8;
// Lower bound on the per-timestep sum over all classes, applied before each
// row of probabilities is divided by that sum.
42 const double CTC::kMinTotalFinalProb_ = 1e-6;
58 std::unique_ptr<CTC> ctc(
new CTC(labels, null_char, outputs));
59 if (!ctc->ComputeLabelLimits()) {
65 ctc->ComputeSimpleTargets(&simple_targets);
67 float bias_fraction = ctc->CalculateBiasFraction();
68 simple_targets *= bias_fraction;
69 ctc->outputs_ += simple_targets;
74 ctc->Forward(&log_alphas);
75 ctc->Backward(&log_betas);
77 log_alphas += log_betas;
78 ctc->NormalizeSequence(&log_alphas);
79 ctc->LabelsToClasses(log_alphas, targets);
86 : labels_(labels), outputs_(outputs), null_char_(null_char) {
87 num_timesteps_ = outputs.
dim1();
88 num_classes_ = outputs.
dim2();
89 num_labels_ = labels_.
size();
// Computes, for every timestep t, the inclusive window of label indices
// [min_labels_[t], max_labels_[t]] that the forward and backward recursions
// visit (both loop u from min_labels_[t] to max_labels_[t]).
// min_labels_ is filled walking backward in time from the end of the label
// sequence, max_labels_ walking forward from the start. Besides the ordinary
// one-label step (on lines outside this excerpt), the visible conditions let a
// step additionally hop over a null whose neighbouring labels differ.
// Returns false as soon as some timestep has max_labels_[t] < min_labels_[t],
// i.e. the labels cannot be aligned within num_timesteps_; ComputeCTCTargets
// checks this result before computing targets.
// NOTE(review): if num_labels_ == 1 and that label is null_char_, min_u below
// starts at -1 and is stored into min_labels_ — confirm callers always supply
// a longer label sequence.
94 bool CTC::ComputeLabelLimits() {
// At the final timestep the alignment must have reached at least the last
// label, or the one before it when the last label is an (optional) trailing
// null.
97 int min_u = num_labels_ - 1;
98 if (labels_[min_u] == null_char_) --min_u;
99 for (
int t = num_timesteps_ - 1; t >= 0; --t) {
100 min_labels_[t] = min_u;
// A null may be skipped only when the labels on either side of it differ
// (same rule as the skip transition in the forward recursion).
103 if (labels_[min_u] == null_char_ && min_u > 0 &&
104 labels_[min_u + 1] != labels_[min_u - 1]) {
// Earliest admissible label at t == 0: index 1 when the sequence starts with
// a null (the bool comparison converts to 0 or 1), otherwise index 0.
109 int max_u = labels_[0] == null_char_;
110 for (
int t = 0; t < num_timesteps_; ++t) {
111 max_labels_[t] = max_u;
// Empty window: the remaining labels no longer fit in the remaining time.
112 if (max_labels_[t] < min_labels_[t])
return false;
113 if (max_u + 1 < num_labels_) {
115 if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
116 labels_[max_u + 1] != labels_[max_u - 1]) {
128 targets->
Resize(num_timesteps_, num_classes_, 0.0f);
131 ComputeWidthsAndMeans(&half_widths, &means);
132 for (
int l = 0; l < num_labels_; ++l) {
133 int label = labels_[l];
134 float left_half_width = half_widths[l];
135 float right_half_width = left_half_width;
137 if (label == null_char_) {
138 if (!NeededNull(l)) {
139 if ((l > 0 && mean == means[l - 1]) ||
140 (l + 1 < num_labels_ && mean == means[l + 1])) {
146 if (l > 0) left_half_width = mean - means[l - 1];
147 if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
149 if (mean >= 0 && mean < num_timesteps_) targets->
put(mean, label, 1.0f);
150 for (
int offset = 1; offset < left_half_width && mean >= offset; ++offset) {
151 float prob = 1.0f - offset / left_half_width;
152 if (mean - offset < num_timesteps_ &&
153 prob > targets->
get(mean - offset, label)) {
154 targets->
put(mean - offset, label, prob);
158 offset < right_half_width && mean + offset < num_timesteps_;
160 float prob = 1.0f - offset / right_half_width;
161 if (mean + offset >= 0 && prob > targets->
get(mean + offset, label)) {
162 targets->
put(mean + offset, label, prob);
175 int num_plus = 0, num_star = 0;
176 for (
int i = 0; i < num_labels_; ++i) {
177 if (labels_[i] != null_char_ || NeededNull(i))
185 float plus_size = 1.0f, star_size = 0.0f;
186 float total_floating = num_plus + num_star;
187 if (total_floating <= num_timesteps_) {
188 plus_size = star_size = num_timesteps_ / total_floating;
189 }
else if (num_star > 0) {
190 star_size =
static_cast<float>(num_timesteps_ - num_plus) / num_star;
193 float mean_pos = 0.0f;
194 for (
int i = 0; i < num_labels_; ++i) {
196 if (labels_[i] != null_char_ || NeededNull(i)) {
197 half_width = plus_size / 2.0f;
199 half_width = star_size / 2.0f;
201 mean_pos += half_width;
202 means->
push_back(static_cast<int>(mean_pos));
203 mean_pos += half_width;
211 int num_classes = outputs.
dim2();
212 const float* outputs_t = outputs[t];
213 for (
int c = 1; c < num_classes; ++c) {
214 if (outputs_t[c] > outputs_t[result]) result = c;
// Measures how closely the best-path decoding of outputs_ already matches
// labels_. Returns the factor that ComputeCTCTargets multiplies the simple
// targets by before adding them into outputs_ (ahead of the forward/backward
// passes).
// The best label at each timestep is taken, consecutive repeats are merged and
// nulls dropped. Per-class counts of that sequence are then compared with the
// per-class counts of labels_, skipping the null class.
// With k = max(true_pos - false_pos, 1) and N = number of non-null truth
// labels, the result is kMinProb_^(k / N). It shrinks toward kMinProb_ as the
// match improves, and is closest to 1 when the match is poor.
// Returns 0 when labels_ contains no non-null labels.
221 float CTC::CalculateBiasFraction() {
224 for (
int t = 0; t < num_timesteps_; ++t) {
225 int label = BestLabel(outputs_, t);
// Advance past the rest of a run of identical best labels so each run is
// counted once.
226 while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
227 if (label != null_char_) output_labels.
push_back(label);
// Per-class histograms of the truth labels and of the collapsed output.
232 for (
int l = 0; l < num_labels_; ++l) {
233 ++truth_counts[labels_[l]];
235 for (
int l = 0; l < output_labels.
size(); ++l) {
236 ++output_counts[output_labels[l]];
// For each non-null class present in the truth: output occurrences beyond the
// truth count are false positives, and the truth count caps the true
// positives.
239 int true_pos = 0, false_pos = 0, total_labels = 0;
240 for (
int c = 0; c < num_classes_; ++c) {
241 if (c == null_char_)
continue;
242 int truth_count = truth_counts[c];
243 int ocr_count = output_counts[c];
244 if (truth_count > 0) {
245 total_labels += truth_count;
246 if (ocr_count > truth_count) {
247 true_pos += truth_count;
248 false_pos += ocr_count - truth_count;
250 true_pos += ocr_count;
256 if (total_labels == 0)
return 0.0f;
// pow(kMinProb_, k / N) written as exp(k * log(kMinProb_) / N); clamping k to
// at least 1 keeps the factor strictly below 1.
257 return exp(std::max(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
// Returns log(exp(ln_x) + exp(ln_y)) computed stably as
// ln_a + log1p(exp(ln_b - ln_a)), so the potentially huge or tiny
// exponentials exp(ln_x) / exp(ln_y) are never formed directly.
// NOTE(review): the condition selecting between the two return statements is
// not in this excerpt. It is presumably ln_x >= ln_y, so the exp() argument is
// never positive — confirm.
263 static double LogSumExp(
double ln_x,
double ln_y) {
265 return ln_x + log1p(exp(ln_y - ln_x));
267 return ln_y + log1p(exp(ln_x - ln_y));
273 log_probs->
Resize(num_timesteps_, num_labels_, -FLT_MAX);
274 log_probs->
put(0, 0, log(outputs_(0, labels_[0])));
275 if (labels_[0] == null_char_)
276 log_probs->
put(0, 1, log(outputs_(0, labels_[1])));
277 for (
int t = 1; t < num_timesteps_; ++t) {
278 const float* outputs_t = outputs_[t];
279 for (
int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
281 double log_sum = log_probs->
get(t - 1, u);
284 log_sum = LogSumExp(log_sum, log_probs->
get(t - 1, u - 1));
287 if (u >= 2 && labels_[u - 1] == null_char_ &&
288 labels_[u] != labels_[u - 2]) {
289 log_sum = LogSumExp(log_sum, log_probs->
get(t - 1, u - 2));
292 double label_prob = outputs_t[labels_[u]];
293 log_sum += log(label_prob);
294 log_probs->
put(t, u, log_sum);
301 log_probs->
Resize(num_timesteps_, num_labels_, -FLT_MAX);
302 log_probs->
put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
303 if (labels_[num_labels_ - 1] == null_char_)
304 log_probs->
put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
305 for (
int t = num_timesteps_ - 2; t >= 0; --t) {
306 const float* outputs_tp1 = outputs_[t + 1];
307 for (
int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
309 double log_sum = log_probs->
get(t + 1, u) + log(outputs_tp1[labels_[u]]);
311 if (u + 1 < num_labels_) {
312 double prev_prob = outputs_tp1[labels_[u + 1]];
314 LogSumExp(log_sum, log_probs->
get(t + 1, u + 1) + log(prev_prob));
317 if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
318 labels_[u] != labels_[u + 2]) {
319 double skip_prob = outputs_tp1[labels_[u + 2]];
321 LogSumExp(log_sum, log_probs->
get(t + 1, u + 2) + log(skip_prob));
323 log_probs->
put(t, u, log_sum);
330 double max_logprob = probs->
Max();
331 for (
int u = 0; u < num_labels_; ++u) {
333 for (
int t = 0; t < num_timesteps_; ++t) {
335 double prob = probs->
get(t, u);
337 prob = ClippedExp(prob - max_logprob);
341 probs->
put(t, u, prob);
346 if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
347 for (
int t = 0; t < num_timesteps_; ++t)
348 probs->
put(t, u, probs->
get(t, u) / total);
356 NetworkIO* targets)
const {
360 for (
int t = 0; t < num_timesteps_; ++t) {
361 float* targets_t = targets->f(t);
363 for (
int u = 0; u < num_labels_; ++u) {
364 double prob = probs(t, u);
368 if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
372 for (
int c = 0; c < num_classes_; ++c) {
373 targets_t[c] = class_probs[c];
374 if (class_probs[c] > class_probs[best_class]) best_class = c;
385 int num_timesteps = probs->
dim1();
386 int num_classes = probs->
dim2();
387 for (
int t = 0; t < num_timesteps; ++t) {
388 float* probs_t = (*probs)[t];
391 for (
int c = 0; c < num_classes; ++c) total += probs_t[c];
392 if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
394 double increment = 0.0;
395 for (
int c = 0; c < num_classes; ++c) {
396 double prob = probs_t[c] / total;
397 if (prob < kMinProb_) increment += kMinProb_ - prob;
401 for (
int c = 0; c < num_classes; ++c) {
402 float prob = probs_t[c] / total;
403 probs_t[c] = std::max(prob, kMinProb_);
// Returns true when labels_[index] is a null strictly inside the label
// sequence whose two neighbours are the same label. Such a null cannot be
// skipped: the forward/backward skip transitions only hop over a null when
// the labels on either side differ. Code that tests
// `labels_[i] != null_char_ || NeededNull(i)` therefore treats it like a real
// label.
// index must be a valid position in labels_; the neighbour accesses are
// guarded by the bounds checks that precede them.
409 bool CTC::NeededNull(
int index)
const {
410 return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ &&
411 labels_[index + 1] == labels_[index - 1];
static void NormalizeProbs(NetworkIO *probs)
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
void init_to_size(int size, const T &t)
void put(ICOORD pos, const T &thing)
void Resize(int size1, int size2, const T &empty)