tesseract  4.00.00dev
ctc.h
Go to the documentation of this file.
1 // File: ctc.h
3 // Description: Slightly improved standard CTC to compute the targets.
4 // Author: Ray Smith
5 // Created: Wed Jul 13 15:17:06 PDT 2016
6 //
7 // (C) Copyright 2016, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_CTC_H_
20 #define TESSERACT_LSTM_CTC_H_
21 
22 #include "genericvector.h"
23 #include "network.h"
24 #include "networkio.h"
25 #include "scrollview.h"
26 
27 namespace tesseract {
28 
29 // Class to encapsulate CTC and simple target generation.
30 class CTC {
31  public:
32  // Normalizes the probabilities such that no target has a prob below min_prob,
33  // and, provided that the initial total is at least min_total_prob, then all
34  // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
35  // probability is thus 1 - (num_classes-1)*min_prob.
36  static void NormalizeProbs(NetworkIO* probs) {
38  }
39 
40  // Builds a target using CTC. Slightly improved as follows:
41  // Includes normalizations and clipping for stability.
42  // labels should be pre-padded with nulls wherever desired, but they don't
43  // have to be between all labels. Allows for multi-label codes with no
44  // nulls between.
45  // labels can be longer than the time sequence, but the total number of
46  // essential labels (non-null plus nulls between equal labels) must not exceed
47  // the number of timesteps in outputs.
48  // outputs is the output of the network, and should have already been
49  // normalized with NormalizeProbs.
50  // On return targets is filled with the computed targets.
51  // Returns false if there is insufficient time for the labels.
52  static bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
53  int null_char,
54  const GENERIC_2D_ARRAY<float>& outputs,
55  NetworkIO* targets);
56 
57  private:
58  // Constructor is private as the instance only holds information specific to
59  // the current labels, outputs etc, and is built by the static function.
60  CTC(const GenericVector<int>& labels, int null_char,
61  const GENERIC_2D_ARRAY<float>& outputs);
62 
63  // Computes vectors of min and max label index for each timestep, based on
64  // whether skippability of nulls makes it possible to complete a valid path.
65  bool ComputeLabelLimits();
66  // Computes targets based purely on the labels by spreading the labels evenly
67  // over the available timesteps.
68  void ComputeSimpleTargets(GENERIC_2D_ARRAY<float>* targets) const;
69  // Computes mean positions and half widths of the simple targets by spreading
70  // the labels even over the available timesteps.
71  void ComputeWidthsAndMeans(GenericVector<float>* half_widths,
72  GenericVector<int>* means) const;
73  // Calculates and returns a suitable fraction of the simple targets to add
74  // to the network outputs.
75  float CalculateBiasFraction();
76  // Runs the forward CTC pass, filling in log_probs.
77  void Forward(GENERIC_2D_ARRAY<double>* log_probs) const;
78  // Runs the backward CTC pass, filling in log_probs.
79  void Backward(GENERIC_2D_ARRAY<double>* log_probs) const;
80  // Normalizes and brings probs out of log space with a softmax over time.
81  void NormalizeSequence(GENERIC_2D_ARRAY<double>* probs) const;
82  // For each timestep computes the max prob for each class over all
83  // instances of the class in the labels_, and sets the targets to
84  // the max observed prob.
85  void LabelsToClasses(const GENERIC_2D_ARRAY<double>& probs,
86  NetworkIO* targets) const;
87  // Normalizes the probabilities such that no target has a prob below min_prob,
88  // and, provided that the initial total is at least min_total_prob, then all
89  // probs will sum to 1, otherwise to sum/min_total_prob. The maximum output
90  // probability is thus 1 - (num_classes-1)*min_prob.
91  static void NormalizeProbs(GENERIC_2D_ARRAY<float>* probs);
92  // Returns true if the label at index is a needed null.
93  bool NeededNull(int index) const;
94  // Returns exp(clipped(x)), clipping x to a reasonable range to prevent over/
95  // underflow.
96  static double ClippedExp(double x) {
97  if (x < -kMaxExpArg_) return exp(-kMaxExpArg_);
98  if (x > kMaxExpArg_) return exp(kMaxExpArg_);
99  return exp(x);
100  }
101 
102  // Minimum probability limit for softmax input to ctc_loss.
103  static const float kMinProb_;
104  // Maximum absolute argument to exp().
105  static const double kMaxExpArg_;
106  // Minimum probability for total prob in time normalization.
107  static const double kMinTotalTimeProb_;
108  // Minimum probability for total prob in final normalization.
109  static const double kMinTotalFinalProb_;
110 
111  // The truth label indices that are to be matched to outputs_.
112  const GenericVector<int>& labels_;
113  // The network outputs.
114  GENERIC_2D_ARRAY<float> outputs_;
115  // The null or "blank" label.
116  int null_char_;
117  // Number of timesteps in outputs_.
118  int num_timesteps_;
119  // Number of classes in outputs_.
120  int num_classes_;
121  // Number of labels in labels_.
122  int num_labels_;
123  // Min and max valid label indices for each timestep.
124  GenericVector<int> min_labels_;
125  GenericVector<int> max_labels_;
126 };
127 
128 } // namespace tesseract
129 
130 #endif // TESSERACT_LSTM_CTC_H_
GENERIC_2D_ARRAY< float > * mutable_float_array()
Definition: networkio.h:140
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53