tesseract  4.00.00dev
lstmrecognizer.h
Go to the documentation of this file.
1 // File: lstmrecognizer.h
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Thu May 02 08:57:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
20 #define TESSERACT_LSTM_LSTMRECOGNIZER_H_
21 
22 #include "ccutil.h"
23 #include "helpers.h"
24 #include "imagedata.h"
25 #include "matrix.h"
26 #include "network.h"
27 #include "networkscratch.h"
28 #include "recodebeam.h"
29 #include "series.h"
30 #include "strngs.h"
31 #include "unicharcompress.h"
32 
33 class BLOB_CHOICE_IT;
34 struct Pix;
35 class ROW_RES;
36 class ScrollView;
37 class TBOX;
38 class WERD_RES;
39 
40 namespace tesseract {
41 
42 class Dict;
43 class ImageData;
44 
45 // Enum indicating training mode control flags.
51 };
52 
53 // Top-level line recognizer class for LSTM-based networks.
54 // Note that a subclass, LSTMTrainer, is used for training.
56  public:
59 
60  int NumOutputs() const {
61  return network_->NumOutputs();
62  }
63  int training_iteration() const {
64  return training_iteration_;
65  }
66  int sample_iteration() const {
67  return sample_iteration_;
68  }
69  double learning_rate() const {
70  return learning_rate_;
71  }
72  bool IsHardening() const {
73  return (training_flags_ & TF_AUTO_HARDEN) != 0;
74  }
76  if (network_ == nullptr) return LT_NONE;
77  StaticShape shape;
78  shape = network_->OutputShape(shape);
79  return shape.loss_type();
80  }
81  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
82  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
83  // True if recoder_ is active to re-encode text to a smaller space.
84  bool IsRecoding() const {
86  }
87  // Returns the cache strategy for the DocumentCache.
90  : CS_SEQUENTIAL;
91  }
92  // Returns true if the network is a TensorFlow network.
93  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
94  // Returns a vector of layer ids that can be passed to other layer functions
95  // to access a specific layer.
97  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
98  Series* series = static_cast<Series*>(network_);
99  GenericVector<STRING> layers;
100  series->EnumerateLayers(NULL, &layers);
101  return layers;
102  }
103  // Returns a specific layer from its id (from EnumerateLayers).
104  Network* GetLayer(const STRING& id) const {
105  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
106  ASSERT_HOST(id.length() > 1 && id[0] == ':');
107  Series* series = static_cast<Series*>(network_);
108  return series->GetLayer(&id[1]);
109  }
110  // Returns the learning rate of the layer from its id.
111  float GetLayerLearningRate(const STRING& id) const {
112  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
114  ASSERT_HOST(id.length() > 1 && id[0] == ':');
115  Series* series = static_cast<Series*>(network_);
116  return series->LayerLearningRate(&id[1]);
117  } else {
118  return learning_rate_;
119  }
120  }
121  // Multiplies all the learning rate(s) by the given factor.
122  void ScaleLearningRate(double factor) {
123  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
124  learning_rate_ *= factor;
127  for (int i = 0; i < layers.size(); ++i) {
128  ScaleLayerLearningRate(layers[i], factor);
129  }
130  }
131  }
132  // Multiplies the learning rate of the layer with id, by the given factor.
133  void ScaleLayerLearningRate(const STRING& id, double factor) {
134  ASSERT_HOST(network_ != NULL && network_->type() == NT_SERIES);
135  ASSERT_HOST(id.length() > 1 && id[0] == ':');
136  Series* series = static_cast<Series*>(network_);
137  series->ScaleLayerLearningRate(&id[1], factor);
138  }
139 
140  // True if the network is using adagrad to train.
141  bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); }
142  // Provides access to the UNICHARSET that this classifier works with.
143  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
144  // Provides access to the Dict that this classifier works with.
145  const Dict* GetDict() const { return dict_; }
146  // Sets the sample iteration to the given value. The sample_iteration_
147  // determines the seed for the random number generator. The training
148  // iteration is incremented only by a successful training iteration.
149  void SetIteration(int iteration) {
150  sample_iteration_ = iteration;
151  }
152  // Accessors for textline image normalization.
153  int NumInputs() const {
154  return network_->NumInputs();
155  }
156  int null_char() const { return null_char_; }
157 
158  // Writes to the given file. Returns false in case of error.
159  bool Serialize(TFile* fp) const;
160  // Reads from the given file. Returns false in case of error.
161  bool DeSerialize(TFile* fp);
162  // Loads the dictionary if possible from the traineddata file.
163  // Prints a warning message, and returns false but otherwise fails silently
164  // and continues to work without it if loading fails.
165  // Note that dictionary load is independent from DeSerialize, but dependent
166  // on the unicharset matching. This enables training to deserialize a model
167  // from checkpoint or restore without having to go back and reload the
168  // dictionary.
169  bool LoadDictionary(const char* lang, TessdataManager* mgr);
170 
171  // Recognizes the line image, contained within image_data, returning the
172  // ratings matrix and matching box_word for each WERD_RES in the output.
173  // If invert, tries inverted as well if the normal interpretation doesn't
174  // produce a good enough result. If use_alternates, the ratings matrix is
175  // filled with segmentation and classifier alternatives that may be searched
176  // using the standard beam search, otherwise, just a diagonal and prebuilt
177  // best_choice. The line_box is used for computing the box_word in the
178  // output words. Score_ratio is used to determine the classifier alternates.
179  // If one_word, then a single WERD_RES is formed, regardless of the spaces
180  // found during recognition.
181  // If not NULL, we attempt to translate the output to target_unicharset, but
182  // do not guarantee success, due to mismatches. In that case the output words
183  // are marked with our UNICHARSET, not the caller's.
184  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
185  double worst_dict_cert, bool use_alternates,
186  const UNICHARSET* target_unicharset, const TBOX& line_box,
187  float score_ratio, bool one_word,
188  PointerVector<WERD_RES>* words);
189  // Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
190  // corresponding to the network output in outputs, labels, label_coords.
191  // one_word generates a single word output, that may include spaces inside.
192  // use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
193  // with cut-offs determined by scale_factor.
194  // If not NULL, we attempt to translate the output to target_unicharset, but
195  // do not guarantee success, due to mismatches. In that case the output words
196  // are marked with our UNICHARSET, not the caller's.
197  void WordsFromOutputs(const NetworkIO& outputs,
198  const GenericVector<int>& labels,
199  const GenericVector<int> label_coords,
200  const TBOX& line_box, bool debug, bool use_alternates,
201  bool one_word, float score_ratio, float scale_factor,
202  const UNICHARSET* target_unicharset,
203  PointerVector<WERD_RES>* words);
204 
205  // Helper computes min and mean best results in the output.
206  void OutputStats(const NetworkIO& outputs,
207  float* min_output, float* mean_output, float* sd);
208  // Recognizes the image_data, returning the labels,
209  // scores, and corresponding pairs of start, end x-coords in coords.
210  // If label_threshold is positive, uses it for making the labels, otherwise
211  // uses standard ctc. Returned in scale_factor is the reduction factor
212  // between the image and the output coords, for computing bounding boxes.
213  // If re_invert is true, the input is inverted back to its original
214  // photometric interpretation if inversion is attempted but fails to
215  // improve the results. This ensures that outputs contains the correct
216  // forward outputs for the best photometric interpretation.
217  // inputs is filled with the used inputs to the network, and if not null,
218  // target boxes is filled with scaled truth boxes if present in image_data.
219  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
220  bool re_invert, float label_threshold, float* scale_factor,
221  NetworkIO* inputs, NetworkIO* outputs);
222  // Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
223  // line_box should be the bounding box of the line image in the main image,
224  // outputs the output of the network,
225  // [word_start, word_end) the interval over which to convert,
226  // score_ratio for choosing alternate classifier choices,
227  // use_alternates to control generation of alternative segmentations,
228  // labels, label_coords, scale_factor from RecognizeLine above.
229  // If target_unicharset is not NULL, attempts to translate the internal
230  // unichar_ids to the target_unicharset, but falls back to untranslated ids
231  // if the translation should fail.
232  WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
233  int word_start, int word_end, float score_ratio,
234  float space_certainty, bool debug,
235  bool use_alternates,
236  const UNICHARSET* target_unicharset,
237  const GenericVector<int>& labels,
238  const GenericVector<int>& label_coords,
239  float scale_factor);
240  // Sets up a word with the ratings matrix and fake blobs with boxes in the
241  // right places.
242  WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
243  float space_certainty, bool use_alternates,
244  const UNICHARSET* target_unicharset,
245  const GenericVector<int>& labels,
246  const GenericVector<int>& label_coords,
247  float scale_factor);
248 
249  // Converts an array of labels to utf-8, whether or not the labels are
250  // augmented with character boundaries.
251  STRING DecodeLabels(const GenericVector<int>& labels);
252 
253  // Displays the forward results in a window with the characters and
254  // boundaries as determined by the labels and label_coords.
255  void DisplayForward(const NetworkIO& inputs,
256  const GenericVector<int>& labels,
257  const GenericVector<int>& label_coords,
258  const char* window_name,
259  ScrollView** window);
260 
261  protected:
262  // Sets the random seed from sample_iteration_.
263  void SetRandomSeed() {
264  inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
265  randomizer_.set_seed(seed);
267  }
268 
269  // Displays the labels and cuts at the corresponding xcoords.
270  // Size of labels should match xcoords.
271  void DisplayLSTMOutput(const GenericVector<int>& labels,
272  const GenericVector<int>& xcoords,
273  int height, ScrollView* window);
274 
275  // Prints debug output detailing the activation path that is implied by the
276  // xcoords.
277  void DebugActivationPath(const NetworkIO& outputs,
278  const GenericVector<int>& labels,
279  const GenericVector<int>& xcoords);
280 
281  // Prints debug output detailing activations and 2nd choice over a range
282  // of positions.
283  void DebugActivationRange(const NetworkIO& outputs, const char* label,
284  int best_choice, int x_start, int x_end);
285 
286  // Converts the network output to a sequence of labels. Outputs labels, scores
287  // and start xcoords of each char, and each null_char_, with an additional
288  // final xcoord for the end of the output.
289  // The conversion method is determined by internal state.
290  void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
291  GenericVector<int>* labels,
292  GenericVector<int>* xcoords);
293  // Converts the network output to a sequence of labels, using a threshold
294  // on the null_char_ to determine character boundaries. Outputs labels, scores
295  // and start xcoords of each char, and each null_char_, with an additional
296  // final xcoord for the end of the output.
297  // The label output is the one with the highest score in the interval between
298  // null_chars_.
299  void LabelsViaThreshold(const NetworkIO& output,
300  float null_threshold,
301  GenericVector<int>* labels,
302  GenericVector<int>* xcoords);
303  // Converts the network output to a sequence of labels, with scores and
304  // start x-coords of the character labels. Retains the null_char_ character as
305  // the end x-coord, where already present, otherwise the start of the next
306  // character is the end.
307  // The number of labels, scores, and xcoords is always matched, except that
308  // there is always an additional xcoord for the last end position.
309  void LabelsViaCTC(const NetworkIO& output,
310  GenericVector<int>* labels,
311  GenericVector<int>* xcoords);
312  // As LabelsViaCTC except that this function constructs the best path that
313  // contains only legal sequences of subcodes for recoder_.
314  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
315  GenericVector<int>* xcoords);
316  // Converts the network output to a sequence of labels, with scores, using
317  // the simple character model (each position is a char, and the null_char_ is
318  // mainly intended for tail padding.)
319  void LabelsViaSimpleText(const NetworkIO& output,
320  GenericVector<int>* labels,
321  GenericVector<int>* xcoords);
322 
323  // Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
324  // Handles either LSTM labels or direct unichar-ids.
325  // Score ratio determines the worst ratio between top choice and remainder.
326  // If target_unicharset is not NULL, attempts to translate to the target
327  // unicharset, returning NULL on failure.
328  BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
329  const NetworkIO& output,
330  const UNICHARSET* target_unicharset,
331  int x_start, int x_end, float score_ratio);
332 
333  // Adds to the given iterator, the blob choices for the target_unicharset
334  // that correspond to the given LSTM unichar_id.
335  // Returns false if unicharset translation failed.
336  bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
337  int row, const UNICHARSET* target_unicharset,
338  BLOB_CHOICE_IT* bc_it);
339 
340  // Returns a string corresponding to the label starting at start. Sets *end
341  // to the next start and if non-null, *decoded to the unichar id.
342  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
343  int* decoded);
344 
345  // Returns a string corresponding to a given single label id, falling back to
346  // a default of ".." for part of a multi-label unichar-id.
347  const char* DecodeSingleLabel(int label);
348 
349  protected:
350  // The network hierarchy.
352  // The unicharset. Only the unicharset element is serialized.
353  // Has to be a CCUtil, so Dict can point to it.
355  // For backward compatibility, recoder_ is serialized iff
356  // training_flags_ & TF_COMPRESS_UNICHARSET.
357  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
359 
360  // ==Training parameters that are serialized to provide a record of them.==
362  // Flags used to determine the training method of the network.
363  // See enum TrainingFlags above.
365  // Number of actual backward training steps used.
367  // Index into training sample set. sample_iteration >= training_iteration_.
369  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
370  // ccutil_.unicharset.size().
372  // Range used for the initial random numbers in the weights.
374  // Learning rate and momentum multipliers of deltas in backprop.
376  float momentum_;
377 
378  // === NOT SERIALIZED.
381  // Language model (optional) to use with the beam search.
383  // Beam search held between uses to optimize memory allocation/use.
385 
386  // == Debugging parameters.==
387  // Recognition debug display window.
389 };
390 
391 } // namespace tesseract.
392 
393 #endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_
float GetLayerLearningRate(const STRING &id) const
LossType loss_type() const
Definition: static_shape.h:48
int64_t inT64
Definition: host.h:40
int32_t inT32
Definition: host.h:38
BLOB_CHOICE_LIST * GetBlobChoices(int col, int row, bool debug, const NetworkIO &output, const UNICHARSET *target_unicharset, int x_start, int x_end, float score_ratio)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET *target_unicharset, const TBOX &line_box, float score_ratio, bool one_word, PointerVector< WERD_RES > *words)
inT32 IntRand()
Definition: helpers.h:55
LossType OutputLossType() const
NetworkType type() const
Definition: network.h:112
WERD_RES * WordFromOutput(const TBOX &line_box, const NetworkIO &outputs, int word_start, int word_end, float score_ratio, float space_certainty, bool debug, bool use_alternates, const UNICHARSET *target_unicharset, const GenericVector< int > &labels, const GenericVector< int > &label_coords, float scale_factor)
const UNICHARSET & GetUnicharset() const
Network * GetLayer(const char *id) const
Definition: plumbing.cpp:148
NetworkScratch scratch_space_
void ScaleLearningRate(double factor)
GenericVector< STRING > EnumerateLayers() const
int size() const
Definition: genericvector.h:72
bool LoadDictionary(const char *lang, TessdataManager *mgr)
#define ASSERT_HOST(x)
Definition: errcode.h:84
void WordsFromOutputs(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > label_coords, const TBOX &line_box, bool debug, bool use_alternates, bool one_word, float score_ratio, float scale_factor, const UNICHARSET *target_unicharset, PointerVector< WERD_RES > *words)
CachingStrategy CacheStrategy() const
const char * DecodeSingleLabel(int label)
Definition: strngs.h:45
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
const Dict * GetDict() const
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
void LabelsFromOutputs(const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
Network * GetLayer(const STRING &id) const
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
UNICHARSET unicharset
Definition: ccutil.h:68
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
double learning_rate() const
void ScaleLayerLearningRate(const char *id, double factor)
Definition: plumbing.h:108
Definition: rect.h:30
void LabelsViaCTC(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
CachingStrategy
Definition: imagedata.h:40
int NumInputs() const
Definition: network.h:120
void SetIteration(int iteration)
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
bool AddBlobChoices(int unichar_id, float rating, float certainty, int col, int row, const UNICHARSET *target_unicharset, BLOB_CHOICE_IT *bc_it)
void LabelsViaThreshold(const NetworkIO &output, float null_threshold, GenericVector< int > *labels, GenericVector< int > *xcoords)
void set_seed(uinT64 seed)
Definition: helpers.h:45
STRING DecodeLabels(const GenericVector< int > &labels)
int NumOutputs() const
Definition: network.h:123
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
WERD_RES * InitializeWord(const TBOX &line_box, int word_start, int word_end, float space_certainty, bool use_alternates, const UNICHARSET *target_unicharset, const GenericVector< int > &labels, const GenericVector< int > &label_coords, float scale_factor)
bool Serialize(TFile *fp) const
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
RecodeBeamSearch * search_
float LayerLearningRate(const char *id) const
Definition: plumbing.h:102
void ScaleLayerLearningRate(const STRING &id, double factor)