tesseract  4.00.00dev
lstmrecognizer.cpp
Go to the documentation of this file.
1 // File: lstmrecognizer.cpp
3 // Description: Top-level line recognizer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Thu May 02 10:59:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 // Include automatically generated configuration file if running autoconf.
20 #ifdef HAVE_CONFIG_H
21 #include "config_auto.h"
22 #endif
23 
24 #include "lstmrecognizer.h"
25 
26 #include "allheaders.h"
27 #include "callcpp.h"
28 #include "dict.h"
29 #include "genericheap.h"
30 #include "helpers.h"
31 #include "imagedata.h"
32 #include "input.h"
33 #include "lstm.h"
34 #include "normalis.h"
35 #include "pageres.h"
36 #include "ratngs.h"
37 #include "recodebeam.h"
38 #include "scrollview.h"
39 #include "shapetable.h"
40 #include "statistc.h"
41 #include "tprintf.h"
42 
43 namespace tesseract {
44 
45 // Max number of blob choices to return in any given position.
// GetBlobChoices trims its sorted choice list when it is longer than this.
46 const int kMaxChoices = 4;
47 // Default ratio between dict and non-dict words.
// Passed by RecognizeLine to search_->Decode when the recognizer is recoding.
48 const double kDictRatio = 2.25;
49 // Default certainty offset to give the dictionary a chance.
// Passed by RecognizeLine to search_->Decode alongside kDictRatio.
50 const double kCertOffset = -0.085;
51 
53  : network_(NULL),
54  training_flags_(0),
55  training_iteration_(0),
56  sample_iteration_(0),
57  null_char_(UNICHAR_BROKEN),
58  weight_range_(0.0f),
59  learning_rate_(0.0f),
60  momentum_(0.0f),
61  dict_(NULL),
62  search_(NULL),
63  debug_win_(NULL) {}
64 
66  delete network_;
67  delete dict_;
68  delete search_;
69 }
70 
71 // Writes to the given file. Returns false in case of error.
73  if (!network_->Serialize(fp)) return false;
74  if (!GetUnicharset().save_to_file(fp)) return false;
75  if (!network_str_.Serialize(fp)) return false;
76  if (fp->FWrite(&training_flags_, sizeof(training_flags_), 1) != 1)
77  return false;
78  if (fp->FWrite(&training_iteration_, sizeof(training_iteration_), 1) != 1)
79  return false;
80  if (fp->FWrite(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
81  return false;
82  if (fp->FWrite(&null_char_, sizeof(null_char_), 1) != 1) return false;
83  if (fp->FWrite(&weight_range_, sizeof(weight_range_), 1) != 1) return false;
84  if (fp->FWrite(&learning_rate_, sizeof(learning_rate_), 1) != 1) return false;
85  if (fp->FWrite(&momentum_, sizeof(momentum_), 1) != 1) return false;
86  if (IsRecoding() && !recoder_.Serialize(fp)) return false;
87  return true;
88 }
89 
90 // Reads from the given file. Returns false in case of error.
92  delete network_;
94  if (network_ == NULL) return false;
95  if (!ccutil_.unicharset.load_from_file(fp, false)) return false;
96  if (!network_str_.DeSerialize(fp)) return false;
97  if (fp->FReadEndian(&training_flags_, sizeof(training_flags_), 1) != 1)
98  return false;
99  if (fp->FReadEndian(&training_iteration_, sizeof(training_iteration_), 1) !=
100  1)
101  return false;
102  if (fp->FReadEndian(&sample_iteration_, sizeof(sample_iteration_), 1) != 1)
103  return false;
104  if (fp->FReadEndian(&null_char_, sizeof(null_char_), 1) != 1) return false;
105  if (fp->FReadEndian(&weight_range_, sizeof(weight_range_), 1) != 1)
106  return false;
107  if (fp->FReadEndian(&learning_rate_, sizeof(learning_rate_), 1) != 1)
108  return false;
109  if (fp->FReadEndian(&momentum_, sizeof(momentum_), 1) != 1) return false;
110  if (IsRecoding()) {
111  if (!recoder_.DeSerialize(fp)) return false;
112  RecodedCharID code;
114  if (code(0) != UNICHAR_SPACE) {
115  tprintf("Space was garbled in recoding!!\n");
116  return false;
117  }
118  }
121  return true;
122 }
123 
124 // Loads the dictionary if possible from the traineddata file.
125 // Prints a warning message, and returns false but otherwise fails silently
126 // and continues to work without it if loading fails.
127 // Note that dictionary load is independent from DeSerialize, but dependent
128 // on the unicharset matching. This enables training to deserialize a model
129 // from checkpoint or restore without having to go back and reload the
130 // dictionary.
132  delete dict_;
133  dict_ = new Dict(&ccutil_);
135  dict_->LoadLSTM(lang, mgr);
136  if (dict_->FinishLoad()) return true; // Success.
137  tprintf("Failed to load any lstm-specific dictionaries for lang %s!!\n",
138  lang);
139  delete dict_;
140  dict_ = NULL;
141  return false;
142 }
143 
144 // Recognizes the line image, contained within image_data, returning the
145 // ratings matrix and matching box_word for each WERD_RES in the output.
146 void LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
147  bool debug, double worst_dict_cert,
148  bool use_alternates,
149  const UNICHARSET* target_unicharset,
150  const TBOX& line_box, float score_ratio,
151  bool one_word,
152  PointerVector<WERD_RES>* words) {
153  NetworkIO outputs;
154  float label_threshold = use_alternates ? 0.75f : 0.0f;
155  float scale_factor;
156  NetworkIO inputs;
157  if (!RecognizeLine(image_data, invert, debug, false, label_threshold,
158  &scale_factor, &inputs, &outputs))
159  return;
160  if (IsRecoding()) {
161  if (search_ == NULL) {
162  search_ =
164  }
165  search_->Decode(outputs, kDictRatio, kCertOffset, worst_dict_cert, NULL);
166  search_->ExtractBestPathAsWords(line_box, scale_factor, debug,
167  &GetUnicharset(), words);
168  } else {
169  GenericVector<int> label_coords;
170  GenericVector<int> labels;
171  LabelsFromOutputs(outputs, label_threshold, &labels, &label_coords);
172  WordsFromOutputs(outputs, labels, label_coords, line_box, debug,
173  use_alternates, one_word, score_ratio, scale_factor,
174  target_unicharset, words);
175  }
176 }
177 
178 // Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
179 // corresponding to the network output in outputs, labels, label_coords.
180 // one_word generates a single word output, that may include spaces inside.
181 // use_alternates generates alternative BLOB_CHOICEs and segmentation paths.
182 // If not NULL, we attempt to translate the output to target_unicharset, but do
183 // not guarantee success, due to mismatches. In that case the output words are
184 // marked with our UNICHARSET, not the caller's.
186  const NetworkIO& outputs, const GenericVector<int>& labels,
187  const GenericVector<int> label_coords, const TBOX& line_box, bool debug,
188  bool use_alternates, bool one_word, float score_ratio, float scale_factor,
189  const UNICHARSET* target_unicharset, PointerVector<WERD_RES>* words) {
190  // Convert labels to unichar-ids.
191  int word_end = 0;
192  float prev_space_cert = 0.0f;
193  for (int i = 0; i < labels.size(); i = word_end) {
194  word_end = i + 1;
195  if (labels[i] == null_char_ || labels[i] == UNICHAR_SPACE) {
196  continue;
197  }
198  float space_cert = 0.0f;
199  if (one_word) {
200  word_end = labels.size();
201  } else {
202  // Find the end of the word at the first null_char_ that leads to the
203  // first UNICHAR_SPACE.
204  while (word_end < labels.size() && labels[word_end] != UNICHAR_SPACE)
205  ++word_end;
206  if (word_end < labels.size()) {
207  float rating;
208  outputs.ScoresOverRange(label_coords[word_end],
209  label_coords[word_end] + 1, UNICHAR_SPACE,
210  null_char_, &rating, &space_cert);
211  }
212  while (word_end > i && labels[word_end - 1] == null_char_) --word_end;
213  }
214  ASSERT_HOST(word_end > i);
215  // Create a WERD_RES for the output word.
216  if (debug)
217  tprintf("Creating word from outputs over [%d,%d)\n", i, word_end);
218  WERD_RES* word =
219  WordFromOutput(line_box, outputs, i, word_end, score_ratio,
220  MIN(prev_space_cert, space_cert), debug,
221  use_alternates && !SimpleTextOutput(), target_unicharset,
222  labels, label_coords, scale_factor);
223  if (word == NULL && target_unicharset != NULL) {
224  // Unicharset translation failed - use decoder_ instead, and disable
225  // the segmentation search on output, as it won't understand the encoding.
226  word = WordFromOutput(line_box, outputs, i, word_end, score_ratio,
227  MIN(prev_space_cert, space_cert), debug, false,
228  NULL, labels, label_coords, scale_factor);
229  }
230  prev_space_cert = space_cert;
231  words->push_back(word);
232  }
233 }
234 
235 // Helper computes min and mean best results in the output.
236 void LSTMRecognizer::OutputStats(const NetworkIO& outputs, float* min_output,
237  float* mean_output, float* sd) {
238  const int kOutputScale = MAX_INT8;
239  STATS stats(0, kOutputScale + 1);
240  for (int t = 0; t < outputs.Width(); ++t) {
241  int best_label = outputs.BestLabel(t, NULL);
242  if (best_label != null_char_ || t == 0) {
243  float best_output = outputs.f(t)[best_label];
244  stats.add(static_cast<int>(kOutputScale * best_output), 1);
245  }
246  }
247  *min_output = static_cast<float>(stats.min_bucket()) / kOutputScale;
248  *mean_output = stats.mean() / kOutputScale;
249  *sd = stats.sd() / kOutputScale;
250 }
251 
252 // Recognizes the image_data, returning the labels,
253 // scores, and corresponding pairs of start, end x-coords in coords.
254 // If label_threshold is positive, uses it for making the labels, otherwise
255 // uses standard ctc.
256 bool LSTMRecognizer::RecognizeLine(const ImageData& image_data, bool invert,
257  bool debug, bool re_invert,
258  float label_threshold, float* scale_factor,
259  NetworkIO* inputs, NetworkIO* outputs) {
260  // Maximum width of image to train on.
261  const int kMaxImageWidth = 2560;
262  // This ensures consistent recognition results.
263  SetRandomSeed();
264  int min_width = network_->XScaleFactor();
265  Pix* pix = Input::PrepareLSTMInputs(image_data, network_, min_width,
266  &randomizer_, scale_factor);
267  if (pix == NULL) {
268  tprintf("Line cannot be recognized!!\n");
269  return false;
270  }
271  if (network_->IsTraining() && pixGetWidth(pix) > kMaxImageWidth) {
272  tprintf("Image too large to learn!! Size = %dx%d\n", pixGetWidth(pix),
273  pixGetHeight(pix));
274  pixDestroy(&pix);
275  return false;
276  }
277  // Reduction factor from image to coords.
278  *scale_factor = min_width / *scale_factor;
279  inputs->set_int_mode(IsIntMode());
280  SetRandomSeed();
282  network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
283  // Check for auto inversion.
284  float pos_min, pos_mean, pos_sd;
285  OutputStats(*outputs, &pos_min, &pos_mean, &pos_sd);
286  if (invert && pos_min < 0.5) {
287  // Run again inverted and see if it is any better.
288  NetworkIO inv_inputs, inv_outputs;
289  inv_inputs.set_int_mode(IsIntMode());
290  SetRandomSeed();
291  pixInvert(pix, pix);
293  &inv_inputs);
294  network_->Forward(debug, inv_inputs, NULL, &scratch_space_, &inv_outputs);
295  float inv_min, inv_mean, inv_sd;
296  OutputStats(inv_outputs, &inv_min, &inv_mean, &inv_sd);
297  if (inv_min > pos_min && inv_mean > pos_mean && inv_sd < pos_sd) {
298  // Inverted did better. Use inverted data.
299  if (debug) {
300  tprintf("Inverting image: old min=%g, mean=%g, sd=%g, inv %g,%g,%g\n",
301  pos_min, pos_mean, pos_sd, inv_min, inv_mean, inv_sd);
302  }
303  *outputs = inv_outputs;
304  *inputs = inv_inputs;
305  } else if (re_invert) {
306  // Inverting was not an improvement, so undo and run again, so the
307  // outputs match the best forward result.
308  SetRandomSeed();
309  network_->Forward(debug, *inputs, NULL, &scratch_space_, outputs);
310  }
311  }
312  pixDestroy(&pix);
313  if (debug) {
314  GenericVector<int> labels, coords;
315  LabelsFromOutputs(*outputs, label_threshold, &labels, &coords);
316  DisplayForward(*inputs, labels, coords, "LSTMForward", &debug_win_);
317  DebugActivationPath(*outputs, labels, coords);
318  }
319  return true;
320 }
321 
322 // Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
323 // line_box should be the bounding box of the line image in the main image,
324 // outputs the output of the network,
325 // [word_start, word_end) the interval over which to convert,
326 // score_ratio for choosing alternate classifier choices,
327 // use_alternates to control generation of alternative segmentations,
328 // labels, label_coords, scale_factor from RecognizeLine above.
329 // If target_unicharset is not NULL, attempts to translate the internal
330 // unichar_ids to the target_unicharset, but falls back to untranslated ids
331 // if the translation should fail.
333  const TBOX& line_box, const NetworkIO& outputs, int word_start,
334  int word_end, float score_ratio, float space_certainty, bool debug,
335  bool use_alternates, const UNICHARSET* target_unicharset,
336  const GenericVector<int>& labels, const GenericVector<int>& label_coords,
337  float scale_factor) {
338  WERD_RES* word_res = InitializeWord(
339  line_box, word_start, word_end, space_certainty, use_alternates,
340  target_unicharset, labels, label_coords, scale_factor);
341  int max_blob_run = word_res->ratings->bandwidth();
342  for (int width = 1; width <= max_blob_run; ++width) {
343  int col = 0;
344  for (int i = word_start; i + width <= word_end; ++i) {
345  if (labels[i] != null_char_) {
346  // Starting at i, use width labels, but stop at the next null_char_.
347  // This forms all combinations of blobs between regions of null_char_.
348  int j = i + 1;
349  while (j - i < width && labels[j] != null_char_) ++j;
350  if (j - i == width) {
351  // Make the blob choices.
352  int end_coord = label_coords[j];
353  if (j < word_end && labels[j] == null_char_)
354  end_coord = label_coords[j + 1];
355  BLOB_CHOICE_LIST* choices = GetBlobChoices(
356  col, col + width - 1, debug, outputs, target_unicharset,
357  label_coords[i], end_coord, score_ratio);
358  if (choices == NULL) {
359  delete word_res;
360  return NULL;
361  }
362  word_res->ratings->put(col, col + width - 1, choices);
363  }
364  ++col;
365  }
366  }
367  }
368  if (use_alternates) {
369  // Merge adjacent single results over null_char boundaries.
370  int col = 0;
371  for (int i = word_start; i + 2 < word_end; ++i) {
372  if (labels[i] != null_char_ && labels[i + 1] == null_char_ &&
373  labels[i + 2] != null_char_ &&
374  (i == word_start || labels[i - 1] == null_char_) &&
375  (i + 3 == word_end || labels[i + 3] == null_char_)) {
376  int end_coord = label_coords[i + 3];
377  if (i + 3 < word_end && labels[i + 3] == null_char_)
378  end_coord = label_coords[i + 4];
379  BLOB_CHOICE_LIST* choices =
380  GetBlobChoices(col, col + 1, debug, outputs, target_unicharset,
381  label_coords[i], end_coord, score_ratio);
382  if (choices == NULL) {
383  delete word_res;
384  return NULL;
385  }
386  word_res->ratings->put(col, col + 1, choices);
387  }
388  if (labels[i] != null_char_) ++col;
389  }
390  } else {
392  }
393  return word_res;
394 }
395 
396 // Sets up a word with the ratings matrix and fake blobs with boxes in the
397 // right places.
398 WERD_RES* LSTMRecognizer::InitializeWord(const TBOX& line_box, int word_start,
399  int word_end, float space_certainty,
400  bool use_alternates,
401  const UNICHARSET* target_unicharset,
402  const GenericVector<int>& labels,
403  const GenericVector<int>& label_coords,
404  float scale_factor) {
405  // Make a fake blob for each non-zero label.
406  C_BLOB_LIST blobs;
407  C_BLOB_IT b_it(&blobs);
408  // num_blobs is the length of the diagonal of the ratings matrix.
409  int num_blobs = 0;
410  // max_blob_run is the diagonal width of the ratings matrix
411  int max_blob_run = 0;
412  int blob_run = 0;
413  for (int i = word_start; i < word_end; ++i) {
414  if (IsRecoding() && !recoder_.IsValidFirstCode(labels[i])) continue;
415  if (labels[i] != null_char_) {
416  // Make a fake blob.
417  TBOX box(label_coords[i], 0, label_coords[i + 1], line_box.height());
418  box.scale(scale_factor);
419  box.move(ICOORD(line_box.left(), line_box.bottom()));
420  box.set_top(line_box.top());
421  b_it.add_after_then_move(C_BLOB::FakeBlob(box));
422  ++num_blobs;
423  ++blob_run;
424  }
425  if (labels[i] == null_char_ || i + 1 == word_end) {
426  if (blob_run > max_blob_run)
427  max_blob_run = blob_run;
428  }
429  }
430  if (!use_alternates) max_blob_run = 1;
431  ASSERT_HOST(label_coords.size() >= word_end);
432  // Make a fake word from the blobs.
433  WERD* word = new WERD(&blobs, word_start > 1 ? 1 : 0, NULL);
434  // Make a WERD_RES from the word.
435  WERD_RES* word_res = new WERD_RES(word);
436  word_res->uch_set =
437  target_unicharset != NULL ? target_unicharset : &GetUnicharset();
438  word_res->combination = true; // Give it ownership of the word.
439  word_res->space_certainty = space_certainty;
440  word_res->ratings = new MATRIX(num_blobs, max_blob_run);
441  return word_res;
442 }
443 
444 // Converts an array of labels to utf-8, whether or not the labels are
445 // augmented with character boundaries.
447  STRING result;
448  int end = 1;
449  for (int start = 0; start < labels.size(); start = end) {
450  if (labels[start] == null_char_) {
451  end = start + 1;
452  } else {
453  result += DecodeLabel(labels, start, &end, NULL);
454  }
455  }
456  return result;
457 }
458 
459 // Displays the forward results in a window with the characters and
460 // boundaries as determined by the labels and label_coords.
462  const GenericVector<int>& labels,
463  const GenericVector<int>& label_coords,
464  const char* window_name,
465  ScrollView** window) {
466 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
467  Pix* input_pix = inputs.ToPix();
468  Network::ClearWindow(false, window_name, pixGetWidth(input_pix),
469  pixGetHeight(input_pix), window);
470  int line_height = Network::DisplayImage(input_pix, *window);
471  DisplayLSTMOutput(labels, label_coords, line_height, *window);
472 #endif // GRAPHICS_DISABLED
473 }
474 
475 // Displays the labels and cuts at the corresponding xcoords.
476 // Size of labels should match xcoords.
478  const GenericVector<int>& xcoords,
479  int height, ScrollView* window) {
480 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
481  int x_scale = network_->XScaleFactor();
482  window->TextAttributes("Arial", height / 4, false, false, false);
483  int end = 1;
484  for (int start = 0; start < labels.size(); start = end) {
485  int xpos = xcoords[start] * x_scale;
486  if (labels[start] == null_char_) {
487  end = start + 1;
488  window->Pen(ScrollView::RED);
489  } else {
490  window->Pen(ScrollView::GREEN);
491  const char* str = DecodeLabel(labels, start, &end, NULL);
492  if (*str == '\\') str = "\\\\";
493  xpos = xcoords[(start + end) / 2] * x_scale;
494  window->Text(xpos, height, str);
495  }
496  window->Line(xpos, 0, xpos, height * 3 / 2);
497  }
498  window->Update();
499 #endif // GRAPHICS_DISABLED
500 }
501 
502 // Prints debug output detailing the activation path that is implied by the
503 // label_coords.
505  const GenericVector<int>& labels,
506  const GenericVector<int>& xcoords) {
507  if (xcoords[0] > 0)
508  DebugActivationRange(outputs, "<null>", null_char_, 0, xcoords[0]);
509  int end = 1;
510  for (int start = 0; start < labels.size(); start = end) {
511  if (labels[start] == null_char_) {
512  end = start + 1;
513  DebugActivationRange(outputs, "<null>", null_char_, xcoords[start],
514  xcoords[end]);
515  continue;
516  } else {
517  int decoded;
518  const char* label = DecodeLabel(labels, start, &end, &decoded);
519  DebugActivationRange(outputs, label, labels[start], xcoords[start],
520  xcoords[start + 1]);
521  for (int i = start + 1; i < end; ++i) {
522  DebugActivationRange(outputs, DecodeSingleLabel(labels[i]), labels[i],
523  xcoords[i], xcoords[i + 1]);
524  }
525  }
526  }
527 }
528 
529 // Prints debug output detailing activations and 2nd choice over a range
530 // of positions.
532  const char* label, int best_choice,
533  int x_start, int x_end) {
534  tprintf("%s=%d On [%d, %d), scores=", label, best_choice, x_start, x_end);
535  double max_score = 0.0;
536  double mean_score = 0.0;
537  int width = x_end - x_start;
538  for (int x = x_start; x < x_end; ++x) {
539  const float* line = outputs.f(x);
540  double score = line[best_choice] * 100.0;
541  if (score > max_score) max_score = score;
542  mean_score += score / width;
543  int best_c = 0;
544  double best_score = 0.0;
545  for (int c = 0; c < outputs.NumFeatures(); ++c) {
546  if (c != best_choice && line[c] > best_score) {
547  best_c = c;
548  best_score = line[c];
549  }
550  }
551  tprintf(" %.3g(%s=%d=%.3g)", score, DecodeSingleLabel(best_c), best_c,
552  best_score * 100.0);
553  }
554  tprintf(", Mean=%g, max=%g\n", mean_score, max_score);
555 }
556 
557 // Helper returns true if the null_char is the winner at t, and it beats the
558 // null_threshold, or the next choice is space, in which case we will use the
559 // null anyway.
560 static bool NullIsBest(const NetworkIO& output, float null_thr,
561  int null_char, int t) {
562  if (output.f(t)[null_char] >= null_thr) return true;
563  if (output.BestLabel(t, null_char, null_char, NULL) != UNICHAR_SPACE)
564  return false;
565  return output.f(t)[null_char] > output.f(t)[UNICHAR_SPACE];
566 }
567 
568 // Converts the network output to a sequence of labels. Outputs labels, scores
569 // and start xcoords of each char, and each null_char_, with an additional
570 // final xcoord for the end of the output.
571 // The conversion method is determined by internal state.
572 void LSTMRecognizer::LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
573  GenericVector<int>* labels,
574  GenericVector<int>* xcoords) {
575  if (SimpleTextOutput()) {
576  LabelsViaSimpleText(outputs, labels, xcoords);
577  } else if (IsRecoding()) {
578  LabelsViaReEncode(outputs, labels, xcoords);
579  } else if (null_thr <= 0.0) {
580  LabelsViaCTC(outputs, labels, xcoords);
581  } else {
582  LabelsViaThreshold(outputs, null_thr, labels, xcoords);
583  }
584 }
585 
586 // Converts the network output to a sequence of labels, using a threshold
587 // on the null_char_ to determine character boundaries. Outputs labels, scores
588 // and start xcoords of each char, and each null_char_, with an additional
589 // final xcoord for the end of the output.
590 // The label output is the one with the highest score in the interval between
591 // null_chars_.
593  float null_thr,
594  GenericVector<int>* labels,
595  GenericVector<int>* xcoords) {
596  labels->truncate(0);
597  xcoords->truncate(0);
598  int width = output.Width();
599  int t = 0;
600  // Skip any initial non-char.
601  while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
602  ++t;
603  }
604  while (t < width) {
605  ASSERT_HOST(!std::isnan(output.f(t)[null_char_]));
606  int label = output.BestLabel(t, null_char_, null_char_, NULL);
607  int char_start = t++;
608  while (t < width && !NullIsBest(output, null_thr, null_char_, t) &&
609  label == output.BestLabel(t, null_char_, null_char_, NULL)) {
610  ++t;
611  }
612  int char_end = t;
613  labels->push_back(label);
614  xcoords->push_back(char_start);
615  // Find the end of the non-char, and compute its score.
616  while (t < width && NullIsBest(output, null_thr, null_char_, t)) {
617  ++t;
618  }
619  if (t > char_end) {
620  labels->push_back(null_char_);
621  xcoords->push_back(char_end);
622  }
623  }
624  xcoords->push_back(width);
625 }
626 
627 // Converts the network output to a sequence of labels, with scores and
628 // start x-coords of the character labels. Retains the null_char_ as the
629 // end x-coord, where already present, otherwise the start of the next
630 // character is the end.
631 // The number of labels, scores, and xcoords is always matched, except that
632 // there is always an additional xcoord for the last end position.
634  GenericVector<int>* labels,
635  GenericVector<int>* xcoords) {
636  labels->truncate(0);
637  xcoords->truncate(0);
638  int width = output.Width();
639  int t = 0;
640  while (t < width) {
641  float score = 0.0f;
642  int label = output.BestLabel(t, &score);
643  labels->push_back(label);
644  xcoords->push_back(t);
645  while (++t < width && output.BestLabel(t, NULL) == label) {
646  }
647  }
648  xcoords->push_back(width);
649 }
650 
651 // As LabelsViaCTC except that this function constructs the best path that
652 // contains only legal sequences of subcodes for CJK.
654  GenericVector<int>* labels,
655  GenericVector<int>* xcoords) {
656  if (search_ == NULL) {
657  search_ =
659  }
660  search_->Decode(output, 1.0, 0.0, RecodeBeamSearch::kMinCertainty, NULL);
661  search_->ExtractBestPathAsLabels(labels, xcoords);
662 }
663 
664 // Converts the network output to a sequence of labels, with scores, using
665 // the simple character model (each position is a char, and the null_char_ is
666 // mainly intended for tail padding.)
668  GenericVector<int>* labels,
669  GenericVector<int>* xcoords) {
670  labels->truncate(0);
671  xcoords->truncate(0);
672  int width = output.Width();
673  for (int t = 0; t < width; ++t) {
674  float score = 0.0f;
675  int label = output.BestLabel(t, &score);
676  if (label != null_char_) {
677  labels->push_back(label);
678  xcoords->push_back(t);
679  }
680  }
681  xcoords->push_back(width);
682 }
683 
684 // Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
685 // Handles either LSTM labels or direct unichar-ids.
686 // Score ratio determines the worst ratio between top choice and remainder.
687 // If target_unicharset is not NULL, attempts to translate to the target
688 // unicharset, returning NULL on failure.
690  int col, int row, bool debug, const NetworkIO& output,
691  const UNICHARSET* target_unicharset, int x_start, int x_end,
692  float score_ratio) {
693  float rating = 0.0f, certainty = 0.0f;
694  int label = output.BestChoiceOverRange(x_start, x_end, UNICHAR_SPACE,
695  null_char_, &rating, &certainty);
696  int unichar_id = label == null_char_ ? UNICHAR_SPACE : label;
697  if (debug) {
698  tprintf("Best choice over range %d,%d=unichar%d=%s r = %g, cert=%g\n",
699  x_start, x_end, unichar_id, DecodeSingleLabel(label), rating,
700  certainty);
701  }
702  BLOB_CHOICE_LIST* choices = new BLOB_CHOICE_LIST;
703  BLOB_CHOICE_IT bc_it(choices);
704  if (!AddBlobChoices(unichar_id, rating, certainty, col, row,
705  target_unicharset, &bc_it)) {
706  delete choices;
707  return NULL;
708  }
709  // Get the other choices.
710  double best_cert = certainty;
711  for (int c = 0; c < output.NumFeatures(); ++c) {
712  if (c == label || c == UNICHAR_SPACE || c == null_char_) continue;
713  // Compute the score over the range.
714  output.ScoresOverRange(x_start, x_end, c, null_char_, &rating, &certainty);
715  int unichar_id = c == null_char_ ? UNICHAR_SPACE : c;
716  if (certainty >= best_cert - score_ratio &&
717  !AddBlobChoices(unichar_id, rating, certainty, col, row,
718  target_unicharset, &bc_it)) {
719  delete choices;
720  return NULL;
721  }
722  }
723  choices->sort(&BLOB_CHOICE::SortByRating);
724  if (bc_it.length() > kMaxChoices) {
725  bc_it.move_to_first();
726  for (int i = 0; i < kMaxChoices; ++i)
727  bc_it.forward();
728  while (!bc_it.at_first()) {
729  delete bc_it.extract();
730  bc_it.forward();
731  }
732  }
733  return choices;
734 }
735 
736 // Adds to the given iterator, the blob choices for the target_unicharset
737 // that correspond to the given LSTM unichar_id.
738 // Returns false if unicharset translation failed.
739 bool LSTMRecognizer::AddBlobChoices(int unichar_id, float rating,
740  float certainty, int col, int row,
741  const UNICHARSET* target_unicharset,
742  BLOB_CHOICE_IT* bc_it) {
743  int target_id = unichar_id;
744  if (target_unicharset != NULL) {
745  const char* utf8 = GetUnicharset().id_to_unichar(unichar_id);
746  if (target_unicharset->contains_unichar(utf8)) {
747  target_id = target_unicharset->unichar_to_id(utf8);
748  } else {
749  return false;
750  }
751  }
752  BLOB_CHOICE* choice = new BLOB_CHOICE(target_id, rating, certainty, -1, 1.0f,
753  static_cast<float>(MAX_INT16), 0.0f,
755  choice->set_matrix_cell(col, row);
756  bc_it->add_after_then_move(choice);
757  return true;
758 }
759 
760 // Returns a string corresponding to the label starting at start. Sets *end
761 // to the next start and if non-null, *decoded to the unichar id.
763  int start, int* end, int* decoded) {
764  *end = start + 1;
765  if (IsRecoding()) {
766  // Decode labels via recoder_.
767  RecodedCharID code;
768  if (labels[start] == null_char_) {
769  if (decoded != NULL) {
770  code.Set(0, null_char_);
771  *decoded = recoder_.DecodeUnichar(code);
772  }
773  return "<null>";
774  }
775  int index = start;
776  while (index < labels.size() &&
778  code.Set(code.length(), labels[index++]);
779  while (index < labels.size() && labels[index] == null_char_) ++index;
780  int uni_id = recoder_.DecodeUnichar(code);
781  // If the next label isn't a valid first code, then we need to continue
782  // extending even if we have a valid uni_id from this prefix.
783  if (uni_id != INVALID_UNICHAR_ID &&
784  (index == labels.size() ||
786  recoder_.IsValidFirstCode(labels[index]))) {
787  *end = index;
788  if (decoded != NULL) *decoded = uni_id;
789  if (uni_id == UNICHAR_SPACE) return " ";
790  return GetUnicharset().get_normed_unichar(uni_id);
791  }
792  }
793  return "<Undecodable>";
794  } else {
795  if (decoded != NULL) *decoded = labels[start];
796  if (labels[start] == null_char_) return "<null>";
797  if (labels[start] == UNICHAR_SPACE) return " ";
798  return GetUnicharset().get_normed_unichar(labels[start]);
799  }
800 }
801 
802 // Returns a string corresponding to a given single label id, falling back to
803 // a default of ".." for part of a multi-label unichar-id.
804 const char* LSTMRecognizer::DecodeSingleLabel(int label) {
805  if (label == null_char_) return "<null>";
806  if (IsRecoding()) {
807  // Decode label via recoder_.
808  RecodedCharID code;
809  code.Set(0, label);
810  label = recoder_.DecodeUnichar(code);
811  if (label == INVALID_UNICHAR_ID) return ".."; // Part of a bigger code.
812  }
813  if (label == UNICHAR_SPACE) return " ";
814  return GetUnicharset().get_normed_unichar(label);
815 }
816 
817 } // namespace tesseract.
virtual int XScaleFactor() const
Definition: network.h:195
void Line(int x1, int y1, int x2, int y2)
Definition: scrollview.cpp:538
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
static C_BLOB * FakeBlob(const TBOX &box)
Definition: stepblob.cpp:238
int Width() const
Definition: networkio.h:107
float * f(int t)
Definition: networkio.h:115
BLOB_CHOICE_LIST * GetBlobChoices(int col, int row, bool debug, const NetworkIO &output, const UNICHARSET *target_unicharset, int x_start, int x_end, float score_ratio)
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
bool IsValidFirstCode(int code) const
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET *target_unicharset, const TBOX &line_box, float score_ratio, bool one_word, PointerVector< WERD_RES > *words)
bool contains_unichar(const char *const unichar_repr) const
Definition: unicharset.cpp:644
static const float kMinCertainty
Definition: recodebeam.h:213
int BestLabel(int t, float *score) const
Definition: networkio.h:161
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:163
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:117
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:89
static int SortByRating(const void *p1, const void *p2)
Definition: ratngs.h:192
#define MAX_INT16
Definition: host.h:61
static const int kMaxCodeLen
WERD_RES * WordFromOutput(const TBOX &line_box, const NetworkIO &outputs, int word_start, int word_end, float score_ratio, float space_certainty, bool debug, bool use_alternates, const UNICHARSET *target_unicharset, const GenericVector< int > &labels, const GenericVector< int > &label_coords, float scale_factor)
int push_back(T object)
const UNICHARSET & GetUnicharset() const
#define tprintf(...)
Definition: tprintf.h:31
NetworkScratch scratch_space_
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:203
void set_matrix_cell(int col, int row)
Definition: ratngs.h:156
const double kCertOffset
bool IsTraining() const
Definition: network.h:115
int FReadEndian(void *buffer, int size, int count)
Definition: serialis.cpp:97
void truncate(int size)
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:100
int size() const
Definition: genericvector.h:72
void ScoresOverRange(int t_start, int t_end, int choice, int null_ch, float *rating, float *certainty) const
Definition: networkio.cpp:450
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: network.h:248
void set_int_mode(bool is_quantized)
Definition: networkio.h:130
bool LoadDictionary(const char *lang, TessdataManager *mgr)
#define ASSERT_HOST(x)
Definition: errcode.h:84
const char * id_to_unichar(UNICHAR_ID id) const
Definition: unicharset.cpp:266
inT16 left() const
Definition: rect.h:68
virtual void CacheXScaleFactor(int factor)
Definition: network.h:201
MATRIX * ratings
Definition: pageres.h:215
void LoadLSTM(const STRING &lang, TessdataManager *data_file)
Definition: dict.cpp:306
void ExtractBestPathAsWords(const TBOX &line_box, float scale_factor, bool debug, const UNICHARSET *unicharset, PointerVector< WERD_RES > *words)
Definition: recodebeam.cpp:138
void scale(const float f)
Definition: rect.h:171
void WordsFromOutputs(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > label_coords, const TBOX &line_box, bool debug, bool use_alternates, bool one_word, float score_ratio, float scale_factor, const UNICHARSET *target_unicharset, PointerVector< WERD_RES > *words)
BOOL8 combination
Definition: pageres.h:318
bool Serialize(TFile *fp) const
virtual StaticShape InputShape() const
Definition: network.h:127
const char * DecodeSingleLabel(int label)
Definition: strngs.h:45
void LabelsViaReEncode(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void FakeWordFromRatings(PermuterType permuter)
Definition: pageres.cpp:893
static void Update()
Definition: scrollview.cpp:715
bool Serialize(FILE *fp) const
Definition: strngs.cpp:148
int DecodeUnichar(const RecodedCharID &code) const
inT32 min_bucket() const
Definition: statistc.cpp:206
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
#define isnan(x)
Definition: mathfix.h:31
void LabelsViaSimpleText(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void LabelsFromOutputs(const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
void DebugActivationRange(const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
int bandwidth() const
Definition: matrix.h:523
UNICHARSET unicharset
Definition: ccutil.h:68
int FWrite(const void *buffer, int size, int count)
Definition: serialis.cpp:148
void add(inT32 value, inT32 count)
Definition: statistc.cpp:101
inT16 top() const
Definition: rect.h:54
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:348
Definition: rect.h:30
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset)
Definition: recodebeam.cpp:76
void LabelsViaCTC(const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
void put(ICOORD pos, const T &thing)
Definition: matrix.h:215
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
#define MIN(x, y)
Definition: ndminx.h:28
Definition: matrix.h:563
void Set(int index, int value)
bool save_to_file(const char *const filename) const
Definition: unicharset.h:308
inT16 height() const
Definition: rect.h:104
int NumFeatures() const
Definition: networkio.h:111
void OutputStats(const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
bool AddBlobChoices(int unichar_id, float rating, float certainty, int col, int row, const UNICHARSET *target_unicharset, BLOB_CHOICE_IT *bc_it)
const double kDictRatio
void Text(int x, int y, const char *mystring)
Definition: scrollview.cpp:658
Definition: statistc.h:33
void LabelsViaThreshold(const NetworkIO &output, float null_threshold, GenericVector< int > *labels, GenericVector< int > *xcoords)
#define MAX_INT8
Definition: host.h:60
inT16 bottom() const
Definition: rect.h:61
const int kMaxChoices
static DawgCache * GlobalDawgCache()
Definition: dict.cpp:198
double mean() const
Definition: statistc.cpp:135
const UNICHARSET * uch_set
Definition: pageres.h:192
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:332
STRING DecodeLabels(const GenericVector< int > &labels)
void SetupForLoad(DawgCache *dawg_cache)
Definition: dict.cpp:206
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
float space_certainty
Definition: pageres.h:300
UNICHAR_ID unichar_to_id(const char *const unichar_repr) const
Definition: unicharset.cpp:194
Pix * ToPix() const
Definition: networkio.cpp:286
Definition: werd.h:60
void DisplayLSTMOutput(const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
WERD_RES * InitializeWord(const TBOX &line_box, int word_start, int word_end, float space_certainty, bool use_alternates, const UNICHARSET *target_unicharset, const GenericVector< int > &labels, const GenericVector< int > &label_coords, float scale_factor)
bool Serialize(TFile *fp) const
bool FinishLoad()
Definition: dict.cpp:327
const char * DecodeLabel(const GenericVector< int > &labels, int start, int *end, int *decoded)
const char * get_normed_unichar(UNICHAR_ID unichar_id) const
Definition: unicharset.h:788
int BestChoiceOverRange(int t_start, int t_end, int not_this, int null_ch, float *rating, float *certainty) const
Definition: networkio.cpp:431
double sd() const
Definition: statistc.cpp:151
RecodeBeamSearch * search_
void Pen(Color color)
Definition: scrollview.cpp:726
integer coordinate
Definition: points.h:30
void TextAttributes(const char *font, int pixel_size, bool bold, bool italic, bool underlined)
Definition: scrollview.cpp:641