tesseract  4.00.00dev
ctc.cpp
Go to the documentation of this file.
1 // File: ctc.cpp
3 // Description: Slightly improved standard CTC to compute the targets.
4 // Author: Ray Smith
5 // Created: Wed Jul 13 15:50:06 PDT 2016
6 //
7 // (C) Copyright 2016, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 #include "ctc.h"
19 
20 #include <memory>
21 
22 #include "genericvector.h"
23 #include "host.h"
24 #include "matrix.h"
25 #include "networkio.h"
26 
27 #include "network.h"
28 #include "scrollview.h"
29 
30 namespace tesseract {
31 
32 // Magic constants that keep CTC stable.
33 // Minimum probability limit for softmax input to ctc_loss.
34 const float CTC::kMinProb_ = 1e-12;
35 // Maximum absolute argument to exp().
36 const double CTC::kMaxExpArg_ = 80.0;
37 // Minimum probability for total prob in time normalization.
38 const double CTC::kMinTotalTimeProb_ = 1e-8;
39 // Minimum probability for total prob in final normalization.
40 const double CTC::kMinTotalFinalProb_ = 1e-6;
41 
42 // Builds a target using CTC. Slightly improved as follows:
43 // Includes normalizations and clipping for stability.
44 // labels should be pre-padded with nulls everywhere.
45 // labels can be longer than the time sequence, but the total number of
46 // essential labels (non-null plus nulls between equal labels) must not exceed
47 // the number of timesteps in outputs.
48 // outputs is the output of the network, and should have already been
49 // normalized with NormalizeProbs.
50 // On return targets is filled with the computed targets.
51 // Returns false if there is insufficient time for the labels.
52 /* static */
53 bool CTC::ComputeCTCTargets(const GenericVector<int>& labels, int null_char,
54  const GENERIC_2D_ARRAY<float>& outputs,
55  NetworkIO* targets) {
56  std::unique_ptr<CTC> ctc(new CTC(labels, null_char, outputs));
57  if (!ctc->ComputeLabelLimits()) {
58  return false; // Not enough time.
59  }
60  // Generate simple targets purely from the truth labels by spreading them
61  // evenly over time.
62  GENERIC_2D_ARRAY<float> simple_targets;
63  ctc->ComputeSimpleTargets(&simple_targets);
64  // Add the simple targets as a starter bias to the network outputs.
65  float bias_fraction = ctc->CalculateBiasFraction();
66  simple_targets *= bias_fraction;
67  ctc->outputs_ += simple_targets;
68  NormalizeProbs(&ctc->outputs_);
69  // Run regular CTC on the biased outputs.
70  // Run forward and backward
71  GENERIC_2D_ARRAY<double> log_alphas, log_betas;
72  ctc->Forward(&log_alphas);
73  ctc->Backward(&log_betas);
74  // Normalize and come out of log space with a clipped softmax over time.
75  log_alphas += log_betas;
76  ctc->NormalizeSequence(&log_alphas);
77  ctc->LabelsToClasses(log_alphas, targets);
78  NormalizeProbs(targets);
79  return true;
80 }
81 
82 CTC::CTC(const GenericVector<int>& labels, int null_char,
83  const GENERIC_2D_ARRAY<float>& outputs)
84  : labels_(labels), outputs_(outputs), null_char_(null_char) {
85  num_timesteps_ = outputs.dim1();
86  num_classes_ = outputs.dim2();
87  num_labels_ = labels_.size();
88 }
89 
90 // Computes vectors of min and max label index for each timestep, based on
91 // whether skippability of nulls makes it possible to complete a valid path.
92 bool CTC::ComputeLabelLimits() {
93  min_labels_.init_to_size(num_timesteps_, 0);
94  max_labels_.init_to_size(num_timesteps_, 0);
95  int min_u = num_labels_ - 1;
96  if (labels_[min_u] == null_char_) --min_u;
97  for (int t = num_timesteps_ - 1; t >= 0; --t) {
98  min_labels_[t] = min_u;
99  if (min_u > 0) {
100  --min_u;
101  if (labels_[min_u] == null_char_ && min_u > 0 &&
102  labels_[min_u + 1] != labels_[min_u - 1]) {
103  --min_u;
104  }
105  }
106  }
107  int max_u = labels_[0] == null_char_;
108  for (int t = 0; t < num_timesteps_; ++t) {
109  max_labels_[t] = max_u;
110  if (max_labels_[t] < min_labels_[t]) return false; // Not enough room.
111  if (max_u + 1 < num_labels_) {
112  ++max_u;
113  if (labels_[max_u] == null_char_ && max_u + 1 < num_labels_ &&
114  labels_[max_u + 1] != labels_[max_u - 1]) {
115  ++max_u;
116  }
117  }
118  }
119  return true;
120 }
121 
122 // Computes targets based purely on the labels by spreading the labels evenly
123 // over the available timesteps.
124 void CTC::ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const {
125  // Initialize all targets to zero.
126  targets->Resize(num_timesteps_, num_classes_, 0.0f);
127  GenericVector<float> half_widths;
128  GenericVector<int> means;
129  ComputeWidthsAndMeans(&half_widths, &means);
130  for (int l = 0; l < num_labels_; ++l) {
131  int label = labels_[l];
132  float left_half_width = half_widths[l];
133  float right_half_width = left_half_width;
134  int mean = means[l];
135  if (label == null_char_) {
136  if (!NeededNull(l)) {
137  if ((l > 0 && mean == means[l - 1]) ||
138  (l + 1 < num_labels_ && mean == means[l + 1])) {
139  continue; // Drop overlapping null.
140  }
141  }
142  // Make sure that no space is left unoccupied and that non-nulls always
143  // peak at 1 by stretching nulls to meet their neighbors.
144  if (l > 0) left_half_width = mean - means[l - 1];
145  if (l + 1 < num_labels_) right_half_width = means[l + 1] - mean;
146  }
147  if (mean >= 0 && mean < num_timesteps_) targets->put(mean, label, 1.0f);
149  float prob = 1.0f - offset / left_half_width;
150  if (mean - offset < num_timesteps_ &&
151  prob > targets->get(mean - offset, label)) {
152  targets->put(mean - offset, label, prob);
153  }
154  }
155  for (int offset = 1;
156  offset < right_half_width && mean + offset < num_timesteps_;
157  ++offset) {
158  float prob = 1.0f - offset / right_half_width;
159  if (mean + offset >= 0 && prob > targets->get(mean + offset, label)) {
160  targets->put(mean + offset, label, prob);
161  }
162  }
163  }
164 }
165 
166 // Computes mean positions and half widths of the simple targets by spreading
167 // the labels evenly over the available timesteps.
168 void CTC::ComputeWidthsAndMeans(GenericVector<float>* half_widths,
169  GenericVector<int>* means) const {
170  // Count the number of labels of each type, in regexp terms, counts plus
171  // (non-null or necessary null, which must occur at least once) and star
172  // (optional null).
173  int num_plus = 0, num_star = 0;
174  for (int i = 0; i < num_labels_; ++i) {
175  if (labels_[i] != null_char_ || NeededNull(i))
176  ++num_plus;
177  else
178  ++num_star;
179  }
180  // Compute the size for each type. If there is enough space for everything
181  // to have size>=1, then all are equal, otherwise plus_size=1 and star gets
182  // whatever is left-over.
183  float plus_size = 1.0f, star_size = 0.0f;
184  float total_floating = num_plus + num_star;
185  if (total_floating <= num_timesteps_) {
186  plus_size = star_size = num_timesteps_ / total_floating;
187  } else if (num_star > 0) {
188  star_size = static_cast<float>(num_timesteps_ - num_plus) / num_star;
189  }
190  // Set the width and compute the mean of each.
191  float mean_pos = 0.0f;
192  for (int i = 0; i < num_labels_; ++i) {
193  float half_width;
194  if (labels_[i] != null_char_ || NeededNull(i)) {
195  half_width = plus_size / 2.0f;
196  } else {
197  half_width = star_size / 2.0f;
198  }
199  mean_pos += half_width;
200  means->push_back(static_cast<int>(mean_pos));
201  mean_pos += half_width;
202  half_widths->push_back(half_width);
203  }
204 }
205 
206 // Helper returns the index of the highest probability label at timestep t.
207 static int BestLabel(const GENERIC_2D_ARRAY<float>& outputs, int t) {
208  int result = 0;
209  int num_classes = outputs.dim2();
210  const float* outputs_t = outputs[t];
211  for (int c = 1; c < num_classes; ++c) {
212  if (outputs_t[c] > outputs_t[result]) result = c;
213  }
214  return result;
215 }
216 
217 // Calculates and returns a suitable fraction of the simple targets to add
218 // to the network outputs.
219 float CTC::CalculateBiasFraction() {
220  // Compute output labels via basic decoding.
221  GenericVector<int> output_labels;
222  for (int t = 0; t < num_timesteps_; ++t) {
223  int label = BestLabel(outputs_, t);
224  while (t + 1 < num_timesteps_ && BestLabel(outputs_, t + 1) == label) ++t;
225  if (label != null_char_) output_labels.push_back(label);
226  }
227  // Simple bag of labels error calculation.
228  GenericVector<int> truth_counts(num_classes_, 0);
229  GenericVector<int> output_counts(num_classes_, 0);
230  for (int l = 0; l < num_labels_; ++l) {
231  ++truth_counts[labels_[l]];
232  }
233  for (int l = 0; l < output_labels.size(); ++l) {
234  ++output_counts[output_labels[l]];
235  }
236  // Count the number of true and false positive non-nulls and truth labels.
237  int true_pos = 0, false_pos = 0, total_labels = 0;
238  for (int c = 0; c < num_classes_; ++c) {
239  if (c == null_char_) continue;
240  int truth_count = truth_counts[c];
241  int ocr_count = output_counts[c];
242  if (truth_count > 0) {
243  total_labels += truth_count;
244  if (ocr_count > truth_count) {
245  true_pos += truth_count;
246  false_pos += ocr_count - truth_count;
247  } else {
248  true_pos += ocr_count;
249  }
250  }
251  // We don't need to count classes that don't exist in the truth as
252  // false positives, because they don't affect CTC at all.
253  }
254  if (total_labels == 0) return 0.0f;
255  return exp(MAX(true_pos - false_pos, 1) * log(kMinProb_) / total_labels);
256 }
257 
// Given ln(x) and ln(y), returns ln(x + y), using:
// ln(x + y) = ln(x) + ln(1 + exp(ln(y) - ln(x))), where ln(x) is taken to be
// the bigger of the two arguments to maximize precision.
// Equal arguments, including two infinities of the same sign, are handled
// without forming inf - inf, which would otherwise yield NaN.
static double LogSumExp(double ln_x, double ln_y) {
  const double larger = ln_x >= ln_y ? ln_x : ln_y;
  const double smaller = ln_x >= ln_y ? ln_y : ln_x;
  // x == y: ln(2x) = ln(x) + ln(2). log1p(1.0) is exactly what the general
  // formula evaluates to for finite equal inputs, since exp(0) == 1.
  if (larger == smaller) return larger + log1p(1.0);
  return larger + log1p(exp(smaller - larger));
}
268 
269 // Runs the forward CTC pass, filling in log_probs.
270 void CTC::Forward(GENERIC_2D_ARRAY<double>* log_probs) const {
271  log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
272  log_probs->put(0, 0, log(outputs_(0, labels_[0])));
273  if (labels_[0] == null_char_)
274  log_probs->put(0, 1, log(outputs_(0, labels_[1])));
275  for (int t = 1; t < num_timesteps_; ++t) {
276  const float* outputs_t = outputs_[t];
277  for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
278  // Continuing the same label.
279  double log_sum = log_probs->get(t - 1, u);
280  // Change from previous label.
281  if (u > 0) {
282  log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 1));
283  }
284  // Skip the null if allowed.
285  if (u >= 2 && labels_[u - 1] == null_char_ &&
286  labels_[u] != labels_[u - 2]) {
287  log_sum = LogSumExp(log_sum, log_probs->get(t - 1, u - 2));
288  }
289  // Add in the log prob of the current label.
290  double label_prob = outputs_t[labels_[u]];
291  log_sum += log(label_prob);
292  log_probs->put(t, u, log_sum);
293  }
294  }
295 }
296 
297 // Runs the backward CTC pass, filling in log_probs.
298 void CTC::Backward(GENERIC_2D_ARRAY<double>* log_probs) const {
299  log_probs->Resize(num_timesteps_, num_labels_, -MAX_FLOAT32);
300  log_probs->put(num_timesteps_ - 1, num_labels_ - 1, 0.0);
301  if (labels_[num_labels_ - 1] == null_char_)
302  log_probs->put(num_timesteps_ - 1, num_labels_ - 2, 0.0);
303  for (int t = num_timesteps_ - 2; t >= 0; --t) {
304  const float* outputs_tp1 = outputs_[t + 1];
305  for (int u = min_labels_[t]; u <= max_labels_[t]; ++u) {
306  // Continuing the same label.
307  double log_sum = log_probs->get(t + 1, u) + log(outputs_tp1[labels_[u]]);
308  // Change from previous label.
309  if (u + 1 < num_labels_) {
310  double prev_prob = outputs_tp1[labels_[u + 1]];
311  log_sum =
312  LogSumExp(log_sum, log_probs->get(t + 1, u + 1) + log(prev_prob));
313  }
314  // Skip the null if allowed.
315  if (u + 2 < num_labels_ && labels_[u + 1] == null_char_ &&
316  labels_[u] != labels_[u + 2]) {
317  double skip_prob = outputs_tp1[labels_[u + 2]];
318  log_sum =
319  LogSumExp(log_sum, log_probs->get(t + 1, u + 2) + log(skip_prob));
320  }
321  log_probs->put(t, u, log_sum);
322  }
323  }
324 }
325 
326 // Normalizes and brings probs out of log space with a softmax over time.
327 void CTC::NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const {
328  double max_logprob = probs->Max();
329  for (int u = 0; u < num_labels_; ++u) {
330  double total = 0.0;
331  for (int t = 0; t < num_timesteps_; ++t) {
332  // Separate impossible path from unlikely probs.
333  double prob = probs->get(t, u);
334  if (prob > -MAX_FLOAT32)
335  prob = ClippedExp(prob - max_logprob);
336  else
337  prob = 0.0;
338  total += prob;
339  probs->put(t, u, prob);
340  }
341  // Note that although this is a probability distribution over time and
342  // therefore should sum to 1, it is important to allow some labels to be
343  // all zero, (or at least tiny) as it is necessary to skip some blanks.
344  if (total < kMinTotalTimeProb_) total = kMinTotalTimeProb_;
345  for (int t = 0; t < num_timesteps_; ++t)
346  probs->put(t, u, probs->get(t, u) / total);
347  }
348 }
349 
350 // For each timestep computes the max prob for each class over all
351 // instances of the class in the labels_, and sets the targets to
352 // the max observed prob.
353 void CTC::LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
354  NetworkIO* targets) const {
355  // For each timestep compute the max prob for each class over all
356  // instances of the class in the labels_.
357  GenericVector<double> class_probs;
358  for (int t = 0; t < num_timesteps_; ++t) {
359  float* targets_t = targets->f(t);
360  class_probs.init_to_size(num_classes_, 0.0);
361  for (int u = 0; u < num_labels_; ++u) {
362  double prob = probs(t, u);
363  // Note that although Graves specifies sum over all labels of the same
364  // class, we need to allow skipped blanks to go to zero, so they don't
365  // interfere with the non-blanks, so max is better than sum.
366  if (prob > class_probs[labels_[u]]) class_probs[labels_[u]] = prob;
367  // class_probs[labels_[u]] += prob;
368  }
369  int best_class = 0;
370  for (int c = 0; c < num_classes_; ++c) {
371  targets_t[c] = class_probs[c];
372  if (class_probs[c] > class_probs[best_class]) best_class = c;
373  }
374  }
375 }
376 
377 // Normalizes the probabilities such that no target has a prob below min_prob,
378 // and, provided that the initial total is at least min_total_prob, then all
379 // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
380 // probability is thus 1 - (num_classes-1)*min_prob.
381 /* static */
383  int num_timesteps = probs->dim1();
384  int num_classes = probs->dim2();
385  for (int t = 0; t < num_timesteps; ++t) {
386  float* probs_t = (*probs)[t];
387  // Compute the total and clip that to prevent amplification of noise.
388  double total = 0.0;
389  for (int c = 0; c < num_classes; ++c) total += probs_t[c];
390  if (total < kMinTotalFinalProb_) total = kMinTotalFinalProb_;
391  // Compute the increased total as a result of clipping.
392  double increment = 0.0;
393  for (int c = 0; c < num_classes; ++c) {
394  double prob = probs_t[c] / total;
395  if (prob < kMinProb_) increment += kMinProb_ - prob;
396  }
397  // Now normalize with clipping. Any additional clipping is negligible.
398  total += increment;
399  for (int c = 0; c < num_classes; ++c) {
400  float prob = probs_t[c] / total;
401  probs_t[c] = MAX(prob, kMinProb_);
402  }
403  }
404 }
405 
406 // Returns true if the label at index is a needed null.
407 bool CTC::NeededNull(int index) const {
408  return labels_[index] == null_char_ && index > 0 && index + 1 < num_labels_ &&
409  labels_[index + 1] == labels_[index - 1];
410 }
411 
412 } // namespace tesseract
double u[max]
float * f(int t)
Definition: networkio.h:115
void init_to_size(int size, T t)
T get(ICOORD pos) const
Definition: matrix.h:223
int push_back(T object)
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:98
voidpf uLong offset
Definition: ioapi.h:42
int size() const
Definition: genericvector.h:72
int dim1() const
Definition: matrix.h:201
int dim2() const
Definition: matrix.h:202
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
#define MAX_FLOAT32
Definition: host.h:66
#define MAX(x, y)
Definition: ndminx.h:24
void put(ICOORD pos, const T &thing)
Definition: matrix.h:215
T Max() const
Definition: matrix.h:337
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53