tesseract  4.00.00dev
lstmtrainer.cpp
Go to the documentation of this file.
1 // File: lstmtrainer.cpp
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Fri May 03 09:14:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 // Include automatically generated configuration file if running autoconf.
20 #ifdef HAVE_CONFIG_H
21 #include "config_auto.h"
22 #endif
23 
24 #include "lstmtrainer.h"
25 #include <string>
26 
27 #include "allheaders.h"
28 #include "boxread.h"
29 #include "ctc.h"
30 #include "imagedata.h"
31 #include "input.h"
32 #include "networkbuilder.h"
33 #include "ratngs.h"
34 #include "recodebeam.h"
35 #ifdef INCLUDE_TENSORFLOW
36 #include "tfnetwork.h"
37 #endif
38 #include "tprintf.h"
39 
40 #include "callcpp.h"
41 
42 namespace tesseract {
43 
44 // Min actual error rate increase to constitute divergence.
45 const double kMinDivergenceRate = 50.0;
46 // Min iterations since last best before acting on a stall.
47 const int kMinStallIterations = 10000;
48 // Fraction of current char error rate that sub_trainer_ has to be ahead
49 // before we declare the sub_trainer_ a success and switch to it.
50 const double kSubTrainerMarginFraction = 3.0 / 128;
51 // Factor to reduce learning rate on divergence.
52 const double kLearningRateDecay = sqrt(0.5);
53 // LR adjustment iterations.
54 const int kNumAdjustmentIterations = 100;
55 // How often to add data to the error_graph_.
56 const int kErrorGraphInterval = 1000;
57 // Number of training images to train between calls to MaintainCheckpoints.
58 const int kNumPagesPerBatch = 100;
59 // Min percent error rate to consider start-up phase over.
60 const int kMinStartedErrorRate = 75;
61 // Error rate at which to transition to stage 1.
62 const double kStageTransitionThreshold = 10.0;
63 // Confidence beyond which the truth is more likely wrong than the recognizer.
64 const double kHighConfidence = 0.9375; // 15/16.
65 // Fraction of weight sign-changing total to constitute a definite improvement.
66 const double kImprovementFraction = 15.0 / 16.0;
67 // Fraction of last written best to make it worth writing another.
68 const double kBestCheckpointFraction = 31.0 / 32.0;
69 // Scale factor for display of target activations of CTC.
70 const int kTargetXScale = 5;
71 const int kTargetYScale = 100;
72 
// NOTE(review): Doxygen-extracted fragment — the default-constructor
// signature (original line 73, presumably LSTMTrainer::LSTMTrainer()) and
// line 82 are missing from this view; verify against the real source.
// Initializes with no data cache, file-based serialization callbacks, and
// self-owned checkpoint read/write callbacks.
74  : training_data_(0),
75  file_reader_(LoadDataFromFile),
76  file_writer_(SaveDataToFile),
77  checkpoint_reader_(
78  NewPermanentTessCallback(this, &LSTMTrainer::ReadTrainingDump)),
79  checkpoint_writer_(
80  NewPermanentTessCallback(this, &LSTMTrainer::SaveTrainingDump)),
81  sub_trainer_(NULL) {
// Debug display disabled by default.
83  debug_interval_ = 0;
84 }
85 
// NOTE(review): fragment of the fully-parameterized constructor. Missing
// lines: 86 (signature start), 97-99 and 101-102 and 105-106 (presumably
// defaulting of file_reader_/file_writer_ and creation of default
// checkpoint callbacks when the supplied ones are NULL) — confirm upstream.
87  CheckPointReader checkpoint_reader,
88  CheckPointWriter checkpoint_writer,
89  const char* model_base, const char* checkpoint_name,
90  int debug_interval, inT64 max_memory)
91  : training_data_(max_memory),
92  file_reader_(file_reader),
93  file_writer_(file_writer),
94  checkpoint_reader_(checkpoint_reader),
95  checkpoint_writer_(checkpoint_writer),
96  sub_trainer_(NULL) {
// Fall back to default checkpoint callbacks when none are provided
// (bodies on the missing lines).
100  if (checkpoint_reader_ == NULL) {
103  }
104  if (checkpoint_writer_ == NULL) {
107  }
108  debug_interval_ = debug_interval;
109  model_base_ = model_base;
110  checkpoint_name_ = checkpoint_name;
111 }
112 
// NOTE(review): destructor fragment — the signature line (original 113,
// LSTMTrainer::~LSTMTrainer()) is missing from this view.
// Releases owned debug windows, the checkpoint callbacks, and any
// sub-trainer created during divergence recovery.
114  delete align_win_;
115  delete target_win_;
116  delete ctc_win_;
117  delete recon_win_;
118  delete checkpoint_reader_;
119  delete checkpoint_writer_;
120  delete sub_trainer_;
121 }
122 
123 // Tries to deserialize a trainer from the given file and silently returns
124 // false in case of failure.
// NOTE(review): signature line (original 125, presumably
// bool LSTMTrainer::TryLoadingCheckpoint(const char* filename)) is missing.
// Reads the whole file into memory, then hands it to the checkpoint reader.
126  GenericVector<char> data;
127  if (!(*file_reader_)(filename, &data)) return false;
128  tprintf("Loaded file %s, unpacking...\n", filename);
129  return checkpoint_reader_->Run(data, this);
130 }
131 
132 // Initializes the character set encode/decode mechanism.
133 // train_flags control training behavior according to the TrainingFlags
134 // enum, including character set encoding.
135 // script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
136 // fully initializes the unicharset from the universal unicharsets.
137 // Note: Call before InitNetwork!
138 void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
139  const STRING& script_dir, int train_flags) {
// NOTE(review): original lines 140 and 143 are missing — line 143 is the
// start of the null_char_ assignment whose tail (": GetUnicharset().size()")
// survives below. Confirm against upstream before relying on this fragment.
141  training_flags_ = train_flags;
142  ccutil_.unicharset.CopyFrom(unicharset);
144  : GetUnicharset().size();
145  SetUnicharsetProperties(script_dir);
146 }
147 
148 // Initializes the character set encode/decode mechanism directly from a
149 // previously setup UNICHARSET and UnicharCompress.
150 // ctc_mode controls how the truth text is mapped to the network targets.
151 // Note: Call before InitNetwork!
152 void LSTMTrainer::InitCharSet(const UNICHARSET& unicharset,
153  const UnicharCompress& recoder) {
// NOTE(review): original lines 154, 159, 162 and 165 are missing (including
// the EncodeUnichar calls that populate 'code'); this fragment is incomplete.
155  int flags = TF_COMPRESS_UNICHARSET;
156  training_flags_ = static_cast<TrainingFlags>(flags);
157  ccutil_.unicharset.CopyFrom(unicharset);
158  recoder_ = recoder;
160  : GetUnicharset().size();
161  RecodedCharID code;
// null_char_ is taken from the first code of the encoded null character.
163  null_char_ = code(0);
164  // Space should encode as itself.
166  ASSERT_HOST(code(0) == UNICHAR_SPACE);
167 }
168 
169 // Initializes the trainer with a network_spec in the network description
170 // net_flags control network behavior according to the NetworkFlags enum.
171 // There isn't really much difference between them - only where the effects
172 // are implemented.
173 // For other args see NetworkBuilder::InitNetwork.
174 // Note: Be sure to call InitCharSet before InitNetwork!
175 bool LSTMTrainer::InitNetwork(const STRING& network_spec, int append_index,
176  int net_flags, float weight_range,
177  float learning_rate, float momentum) {
178  // Call after InitCharSet.
// NOTE(review): original lines 179, 181 (presumably learning_rate_
// assignment) and 197 (arguments of the final tprintf) are missing here.
180  weight_range_ = weight_range;
182  momentum_ = momentum;
// Output count: one extra class for the null char unless the charset
// already reserves it; recoding overrides with the recoder's code range.
183  int num_outputs = null_char_ == GetUnicharset().size()
184  ? null_char_ + 1
185  : GetUnicharset().size();
186  if (IsRecoding()) num_outputs = recoder_.code_range();
187  if (!NetworkBuilder::InitNetwork(num_outputs, network_spec, append_index,
188  net_flags, weight_range, &randomizer_,
189  &network_)) {
190  return false;
191  }
192  network_str_ += network_spec;
193  tprintf("Built network:%s from request %s\n",
194  network_->spec().string(), network_spec.string());
195  tprintf("Training parameters:\n Debug interval = %d,"
196  " weights = %g, learning rate = %g, momentum=%g\n",
198  return true;
199 }
200 
201 // Initializes a trainer from a serialized TFNetworkModel proto.
202 // Returns the global step of TensorFlow graph or 0 if failed.
203 int LSTMTrainer::InitTensorFlowNetwork(const std::string& tf_proto) {
204 #ifdef INCLUDE_TENSORFLOW
205  delete network_;
206  TFNetwork* tf_net = new TFNetwork("TensorFlow");
207  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
208  if (training_iteration_ == 0) {
209  tprintf("InitFromProtoStr failed!!\n");
210  return 0;
211  }
212  network_ = tf_net;
213  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
214  return training_iteration_;
215 #else
216  tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
217  return 0;
218 #endif
219 }
220 
221 // Resets all the iteration counters for fine tuning or traininng a head,
222 // where we want the error reporting to reset.
// NOTE(review): fragment of the iteration-counter reset (presumably
// LSTMTrainer::InitIterations(); signature line 223 missing). Original
// lines 225-227, 232-233, 235, 239 and 242 are also missing; confirm the
// full reset list upstream.
224  sample_iteration_ = 0;
// Best/worst tracking restarts from the extremes.
228  best_error_rate_ = 100.0;
229  best_iteration_ = 0;
230  worst_error_rate_ = 0.0;
231  worst_iteration_ = 0;
234  perfect_delay_ = 0;
236  for (int i = 0; i < ET_COUNT; ++i) {
237  best_error_rates_[i] = 100.0;
238  worst_error_rates_[i] = 0.0;
240  error_rates_[i] = 100.0;
241  }
243 }
244 
245 // If the training sample is usable, grid searches for the optimal
246 // dict_ratio/cert_offset, and returns the results in a string of space-
247 // separated triplets of ratio,offset=worderr.
// NOTE(review): GridSearchDictParams fragment — the signature start
// (original line 248) and line 272 (the declaration of 'search', presumably
// a dict-enabled RecodeBeamSearch) are missing from this view.
249  const ImageData* trainingdata, int iteration, double min_dict_ratio,
250  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
251  double cert_offset_step, double max_cert_offset, STRING* results) {
252  sample_iteration_ = iteration;
253  NetworkIO fwd_outputs, targets;
// Run forward/target preparation first; unusable samples short-circuit.
254  Trainability result =
255  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
256  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == NULL)
257  return result;
258 
259  // Encode/decode the truth to get the normalization.
260  GenericVector<int> truth_labels, ocr_labels, xcoords;
261  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
262  // NO-dict error.
263  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), NULL);
264  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
265  NULL);
266  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
267  STRING truth_text = DecodeLabels(truth_labels);
268  STRING ocr_text = DecodeLabels(ocr_labels);
269  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
270  results->add_str_double("0,0=", baseline_error);
271 
// Grid-search all (ratio, offset) pairs, appending "r,c=err" triplets.
273  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
274  for (double c = min_cert_offset; c < max_cert_offset;
275  c += cert_offset_step) {
276  search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, NULL);
277  search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
278  truth_text = DecodeLabels(truth_labels);
279  ocr_text = DecodeLabels(ocr_labels);
280  // This is destructive on both strings.
281  double word_error = ComputeWordError(&truth_text, &ocr_text);
282  if ((r == min_dict_ratio && c == min_cert_offset) ||
283  !std::isfinite(word_error)) {
284  STRING t = DecodeLabels(truth_labels);
285  STRING o = DecodeLabels(ocr_labels);
286  tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
287  t.string(), o.string(), word_error, truth_labels[0]);
288  }
289  results->add_str_double(" ", r);
290  results->add_str_double(",", c);
291  results->add_str_double("=", word_error);
292  }
293  }
294  return result;
295 }
296 
297 // Provides output on the distribution of weight values.
// NOTE(review): the body (original lines 298-299, presumably a single call
// into the network's debug printer) is missing from this extracted view.
300 }
301 
302 // Loads a set of lstmf files that were created using the lstm.train config to
303 // tesseract into memory ready for training. Returns false if nothing was
304 // loaded.
// NOTE(review): the function signature and body (original lines 305-307)
// are missing from this extracted view.
308 }
309 
310 // Keeps track of best and locally worst char error_rate and launches tests
311 // using tester, when a new min or max is reached.
312 // Writes checkpoints at appropriate times and builds and returns a log message
313 // to indicate progress. Returns false if nothing interesting happened.
// NOTE(review): MaintainCheckpoints fragment — original lines 314
// (signature), 345, 349, 356 and 380 are missing (345/349 presumably update
// best_* bookkeeping; 380 presumably re-saves best_trainer_). Confirm
// against upstream before editing logic.
315  PrepareLogMsg(log_msg);
316  double error_rate = CharError();
317  int iteration = learning_iteration();
// Stalled and clearly worse than best: spawn a sub-trainer from the best
// model with reduced learning rates.
318  if (iteration >= stall_iteration_ &&
319  error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
320  best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
321  // It hasn't got any better in a long while, and is a margin worse than the
322  // best, so go back to the best model and try a different learning rate.
323  StartSubtrainer(log_msg);
324  }
325  SubTrainerResult sub_trainer_result = STR_NONE;
326  if (sub_trainer_ != NULL) {
327  sub_trainer_result = UpdateSubtrainer(log_msg);
328  if (sub_trainer_result == STR_REPLACED) {
329  // Reset the inputs, as we have overwritten *this.
330  error_rate = CharError();
331  iteration = learning_iteration();
332  PrepareLogMsg(log_msg);
333  }
334  }
335  bool result = true; // Something interesting happened.
336  GenericVector<char> rec_model_data;
337  if (error_rate < best_error_rate_) {
338  SaveRecognitionDump(&rec_model_data);
339  log_msg->add_str_double(" New best char error = ", error_rate);
340  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
341  // If sub_trainer_ is not NULL, either *this beat it to a new best, or it
342  // just overwrote *this. In either case, we have finished with it.
343  delete sub_trainer_;
344  sub_trainer_ = NULL;
346  if (TransitionTrainingStage(kStageTransitionThreshold)) {
347  log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
348  }
// Only write a new best model when it beats the last saved one by the
// kBestCheckpointFraction margin, to limit disk churn.
350  if (error_rate < error_rate_of_last_saved_best_ * kBestCheckpointFraction) {
351  STRING best_model_name = DumpFilename();
352  if (!(*file_writer_)(best_trainer_, best_model_name)) {
353  *log_msg += " failed to write best model:";
354  } else {
355  *log_msg += " wrote best model:";
357  }
358  *log_msg += best_model_name;
359  }
360  } else if (error_rate > worst_error_rate_) {
361  SaveRecognitionDump(&rec_model_data);
362  log_msg->add_str_double(" New worst char error = ", error_rate);
363  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
364  if (worst_error_rate_ > best_error_rate_ + kMinDivergenceRate &&
365  best_error_rate_ < kMinStartedErrorRate && !best_trainer_.empty()) {
366  // Error rate has ballooned. Go back to the best model.
367  *log_msg += "\nDivergence! ";
368  // Copy best_trainer_ before reading it, as it will get overwritten.
369  GenericVector<char> revert_data(best_trainer_);
370  if (checkpoint_reader_->Run(revert_data, this)) {
371  LogIterations("Reverted to", log_msg);
372  ReduceLearningRates(this, log_msg);
373  } else {
374  LogIterations("Failed to Revert at", log_msg);
375  }
376  // If it fails again, we will wait twice as long before reverting again.
377  stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
378  // Re-save the best trainer with the new learning rates and stall
379  // iteration.
381  }
382  } else {
383  // Something interesting happened only if the sub_trainer_ was trained.
384  result = sub_trainer_result != STR_NONE;
385  }
// Always attempt to write a rolling checkpoint when configured to do so.
386  if (checkpoint_writer_ != NULL && file_writer_ != NULL &&
387  checkpoint_name_.length() > 0) {
388  // Write a current checkpoint.
389  GenericVector<char> checkpoint;
390  if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
391  !(*file_writer_)(checkpoint, checkpoint_name_)) {
392  *log_msg += " failed to write checkpoint.";
393  } else {
394  *log_msg += " wrote checkpoint.";
395  }
396  }
397  *log_msg += "\n";
398  return result;
399 }
400 
401 // Builds a string containing a progress message with current error rates.
402 void LSTMTrainer::PrepareLogMsg(STRING* log_msg) const {
403  LogIterations("At", log_msg);
404  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
405  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
406  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
407  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
408  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
409  *log_msg += "%, ";
410 }
411 
412 // Appends <intro_str> iteration learning_iteration()/training_iteration()/
413 // sample_iteration() to the log_msg.
414 void LSTMTrainer::LogIterations(const char* intro_str, STRING* log_msg) const {
415  *log_msg += intro_str;
416  log_msg->add_str_int(" iteration ", learning_iteration());
417  log_msg->add_str_int("/", training_iteration());
418  log_msg->add_str_int("/", sample_iteration());
419 }
420 
421 // Returns true and increments the training_stage_ if the error rate has just
422 // passed through the given threshold for the first time.
423 bool LSTMTrainer::TransitionTrainingStage(float error_threshold) {
// NOTE(review): the second half of this condition (original line 425,
// presumably a check that training_stage_ has not already advanced) is
// missing from this extracted view.
424  if (best_error_rate_ < error_threshold &&
426  ++training_stage_;
427  return true;
428  }
429  return false;
430 }
431 
432 // Writes to the given file. Returns false in case of error.
433 bool LSTMTrainer::Serialize(TFile* fp) const {
// Writes recognizer state first, then trainer bookkeeping; serialize_amount_
// gates how much history (LIGHT vs full) is written.
// NOTE(review): original line 441 (the FWrite start for
// last_perfect_training_iteration_) is missing from this view.
434  if (!LSTMRecognizer::Serialize(fp)) return false;
435  if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
436  return false;
437  if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
438  1)
439  return false;
440  if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false;
442  sizeof(last_perfect_training_iteration_), 1) != 1)
443  return false;
444  for (int i = 0; i < ET_COUNT; ++i) {
445  if (!error_buffers_[i].Serialize(fp)) return false;
446  }
447  if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
448  if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1)
449  return false;
450  uinT8 amount = serialize_amount_;
451  if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false;
452  if (amount == LIGHT) return true; // We are done.
453  if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
454  return false;
455  if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
456  return false;
457  if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1)
458  return false;
459  if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
460  return false;
461  if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
462  return false;
463  if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
464  return false;
465  if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
466  return false;
467  if (!best_model_data_.Serialize(fp)) return false;
468  if (!worst_model_data_.Serialize(fp)) return false;
469  if (amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp)) return false;
// sub_trainer_ is serialized indirectly as a LIGHT dump; empty if absent.
470  GenericVector<char> sub_data;
471  if (sub_trainer_ != NULL && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
472  return false;
473  if (!sub_data.Serialize(fp)) return false;
474  if (!best_error_history_.Serialize(fp)) return false;
475  if (!best_error_iterations_.Serialize(fp)) return false;
476  if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
477  return false;
478  return true;
479 }
480 
481 // Reads from the given file. Returns false in case of error.
482 // NOTE: It is assumed that the trainer is never read cross-endian.
// NOTE(review): DeSerialize fragment — original lines 483 (signature),
// 490-491, 494 (start of the prev_sample_iteration_ read) and 499 (start of
// the last_perfect_training_iteration_ read) are missing from this view.
// Mirrors Serialize(): recognizer first, then trainer state, gated by the
// serialized 'amount'.
484  if (!LSTMRecognizer::DeSerialize(fp)) return false;
485  if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
486  // Special case. If we successfully decoded the recognizer, but fail here
487  // then it means we were just given a recognizer, so issue a warning and
488  // allow it.
489  tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
492  return true;
493  }
495  1) != 1)
496  return false;
497  if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1)
498  return false;
500  sizeof(last_perfect_training_iteration_), 1) != 1)
501  return false;
502  for (int i = 0; i < ET_COUNT; ++i) {
503  if (!error_buffers_[i].DeSerialize(fp)) return false;
504  }
505  if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
506  if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1)
507  return false;
508  uinT8 amount;
509  if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false;
510  if (amount == LIGHT) return true; // Don't read the rest.
511  if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
512  return false;
513  if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
514  return false;
515  if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1)
516  return false;
517  if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
518  return false;
519  if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
520  return false;
521  if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
522  return false;
523  if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
524  return false;
525  if (!best_model_data_.DeSerialize(fp)) return false;
526  if (!worst_model_data_.DeSerialize(fp)) return false;
527  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
// Rebuild sub_trainer_ from its embedded LIGHT dump, if one was saved.
528  GenericVector<char> sub_data;
529  if (!sub_data.DeSerialize(fp)) return false;
530  delete sub_trainer_;
531  if (sub_data.empty()) {
532  sub_trainer_ = NULL;
533  } else {
534  sub_trainer_ = new LSTMTrainer();
535  if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
536  }
537  if (!best_error_history_.DeSerialize(fp)) return false;
538  if (!best_error_iterations_.DeSerialize(fp)) return false;
539  if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
540  return false;
541  return true;
542 }
543 
544 // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
545 // learning rates (by scaling reduction, or layer specific, according to
546 // NF_LAYER_SPECIFIC_LR).
// NOTE(review): StartSubtrainer fragment — original lines 547 (signature),
// 550 (presumably the ReadTrainingDump(best_trainer_, ...) condition), 556,
// 561, 563 and 565 are missing; this view is incomplete.
548  delete sub_trainer_;
549  sub_trainer_ = new LSTMTrainer();
// Failure branch: could not load the saved best model into the sub-trainer.
551  *log_msg += " Failed to revert to previous best for trial!";
552  delete sub_trainer_;
553  sub_trainer_ = NULL;
554  } else {
555  log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
557  // Reduce learning rate so it doesn't diverge this time.
558  sub_trainer_->ReduceLearningRates(this, log_msg);
559  // If it fails again, we will wait twice as long before reverting again.
560  int stall_offset =
562  stall_iteration_ = learning_iteration() + 2 * stall_offset;
564  // Re-save the best trainer with the new learning rates and stall iteration.
566  }
567 }
568 
569 // While the sub_trainer_ is behind the current training iteration and its
570 // training error is at least kSubTrainerMarginFraction better than the
571 // current training error, trains the sub_trainer_, and returns STR_UPDATED if
572 // it did anything. If it catches up, and has a better error rate than the
573 // current best, as well as a margin over the current error rate, then the
574 // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
575 // returned. STR_NONE is returned if the subtrainer wasn't good enough to
576 // receive any training iterations.
// NOTE(review): UpdateSubtrainer fragment — original lines 577 (signature),
// 590 (completion of the target_iteration expression, presumably rounding to
// a kNumPagesPerBatch boundary) and 609 are missing from this view.
578  double training_error = CharError();
579  double sub_error = sub_trainer_->CharError();
// Relative margin of the sub-trainer over the main trainer.
580  double sub_margin = (training_error - sub_error) / sub_error;
581  if (sub_margin >= kSubTrainerMarginFraction) {
582  log_msg->add_str_double(" sub_trainer=", sub_error);
583  log_msg->add_str_double(" margin=", 100.0 * sub_margin);
584  *log_msg += "\n";
585  // Catch up to current iteration.
586  int end_iteration = training_iteration();
587  while (sub_trainer_->training_iteration() < end_iteration &&
588  sub_margin >= kSubTrainerMarginFraction) {
589  int target_iteration =
591  while (sub_trainer_->training_iteration() < target_iteration) {
592  sub_trainer_->TrainOnLine(this, false);
593  }
594  STRING batch_log = "Sub:";
595  sub_trainer_->PrepareLogMsg(&batch_log);
596  batch_log += "\n";
597  tprintf("UpdateSubtrainer:%s", batch_log.string());
598  *log_msg += batch_log;
599  sub_error = sub_trainer_->CharError();
600  sub_margin = (training_error - sub_error) / sub_error;
601  }
602  if (sub_error < best_error_rate_ &&
603  sub_margin >= kSubTrainerMarginFraction) {
604  // The sub_trainer_ has won the race to a new best. Switch to it.
605  GenericVector<char> updated_trainer;
606  SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
607  ReadTrainingDump(updated_trainer, this);
608  log_msg->add_str_int(" Sub trainer wins at iteration ",
610  *log_msg += "\n";
611  return STR_REPLACED;
612  }
613  return STR_UPDATED;
614  }
615  return STR_NONE;
616 }
617 
618 // Reduces network learning rates, either for everything, or for layers
619 // independently, according to NF_LAYER_SPECIFIC_LR.
// NOTE(review): ReduceLearningRates fragment — original lines 620
// (signature start) and 622 (the NF_LAYER_SPECIFIC_LR condition) are
// missing; the if/else below chooses per-layer vs global decay.
621  STRING* log_msg) {
623  int num_reduced = ReduceLayerLearningRates(
624  kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
625  log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
626  } else {
627  ScaleLearningRate(kLearningRateDecay);
628  log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
629  }
630  *log_msg += "\n";
631 }
632 
633 // Considers reducing the learning rate independently for each layer down by
634 // factor(<1), or leaving it the same, by double-training the given number of
635 // samples and minimizing the amount of changing of sign of weight updates.
636 // Even if it looks like all weights should remain the same, an adjustment
637 // will be made to guarantee a different result when reverting to an old best.
638 // Returns the number of layer learning rates that were reduced.
639 int LSTMTrainer::ReduceLayerLearningRates(double factor, int num_samples,
640  LSTMTrainer* samples_trainer) {
641  enum WhichWay {
642  LR_DOWN, // Learning rate will go down by factor.
643  LR_SAME, // Learning rate will stay the same.
644  LR_COUNT // Size of arrays.
645  };
646  // Epsilon is so small that it may as well be zero, but still positive.
647  const double kEpsilon = 1.0e-30;
// NOTE(review): original line 648 is missing here — presumably the
// declaration of 'layers' (the enumeration of trainable layer ids) used
// throughout the function. Confirm against upstream.
649  int num_layers = layers.size();
650  GenericVector<int> num_weights;
651  num_weights.init_to_size(num_layers, 0);
// Per-hypothesis (DOWN/SAME) accumulators of weight-sign-change statistics.
652  GenericVector<double> bad_sums[LR_COUNT];
653  GenericVector<double> ok_sums[LR_COUNT];
654  for (int i = 0; i < LR_COUNT; ++i) {
655  bad_sums[i].init_to_size(num_layers, 0.0);
656  ok_sums[i].init_to_size(num_layers, 0.0);
657  }
658  double momentum_factor = 1.0 / (1.0 - momentum_);
659  GenericVector<char> orig_trainer;
660  SaveTrainingDump(LIGHT, this, &orig_trainer);
661  for (int i = 0; i < num_layers; ++i) {
662  Network* layer = GetLayer(layers[i]);
663  num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
664  }
665  int iteration = sample_iteration();
666  for (int s = 0; s < num_samples; ++s) {
667  // Which way will we modify the learning rate?
668  for (int ww = 0; ww < LR_COUNT; ++ww) {
669  // Transfer momentum to learning rate and adjust by the ww factor.
670  float ww_factor = momentum_factor;
671  if (ww == LR_DOWN) ww_factor *= factor;
672  // Make a copy of *this, so we can mess about without damaging anything.
673  LSTMTrainer copy_trainer;
674  copy_trainer.ReadTrainingDump(orig_trainer, &copy_trainer);
675  // Clear the updates, doing nothing else.
676  copy_trainer.network_->Update(0.0, 0.0, 0);
677  // Adjust the learning rate in each layer.
678  for (int i = 0; i < num_layers; ++i) {
679  if (num_weights[i] == 0) continue;
680  copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
681  }
682  copy_trainer.SetIteration(iteration);
683  // Train on the sample, but keep the update in updates_ instead of
684  // applying to the weights.
685  const ImageData* trainingdata =
686  copy_trainer.TrainOnLine(samples_trainer, true);
687  if (trainingdata == NULL) continue;
688  // We'll now use this trainer again for each layer.
689  GenericVector<char> updated_trainer;
690  SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
691  for (int i = 0; i < num_layers; ++i) {
692  if (num_weights[i] == 0) continue;
693  LSTMTrainer layer_trainer;
694  layer_trainer.ReadTrainingDump(updated_trainer, &layer_trainer);
695  Network* layer = layer_trainer.GetLayer(layers[i]);
696  // Update the weights in just the layer, and also zero the updates
697  // matrix (to epsilon).
698  layer->Update(0.0, kEpsilon, 0);
699  // Train again on the same sample, again holding back the updates.
700  layer_trainer.TrainOnLine(trainingdata, true);
701  // Count the sign changes in the updates in layer vs in copy_trainer.
702  float before_bad = bad_sums[ww][i];
703  float before_ok = ok_sums[ww][i];
704  layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
705  &ok_sums[ww][i], &bad_sums[ww][i]);
706  float bad_frac =
707  bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
708  if (bad_frac > 0.0f)
709  bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
710  }
711  }
712  ++iteration;
713  }
// Decide per layer: reduce LR if the DOWN hypothesis shows materially
// fewer sign changes than SAME (kImprovementFraction margin).
714  int num_lowered = 0;
715  for (int i = 0; i < num_layers; ++i) {
716  if (num_weights[i] == 0) continue;
717  Network* layer = GetLayer(layers[i]);
718  float lr = GetLayerLearningRate(layers[i]);
719  double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
720  double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
721  double frac_down = bad_sums[LR_DOWN][i] / total_down;
722  double frac_same = bad_sums[LR_SAME][i] / total_same;
723  tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
724  lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
725  if (frac_down < frac_same * kImprovementFraction) {
726  tprintf(" REDUCED\n");
727  ScaleLayerLearningRate(layers[i], factor);
728  ++num_lowered;
729  } else {
730  tprintf(" SAME\n");
731  }
732  }
733  if (num_lowered == 0) {
734  // Just lower everything to make sure.
735  for (int i = 0; i < num_layers; ++i) {
736  if (num_weights[i] > 0) {
737  ScaleLayerLearningRate(layers[i], factor);
738  ++num_lowered;
739  }
740  }
741  }
742  return num_lowered;
743 }
744 
745 // Converts the string to integer class labels, with appropriate null_char_s
746 // in between if not in SimpleTextOutput mode. Returns false on failure.
747 /* static */
748 bool LSTMTrainer::EncodeString(const STRING& str, const UNICHARSET& unicharset,
749  const UnicharCompress* recoder, bool simple_text,
750  int null_char, GenericVector<int>* labels) {
751  if (str.string() == NULL || str.length() <= 0) {
752  tprintf("Empty truth string!\n");
753  return false;
754  }
755  int err_index;
756  GenericVector<int> internal_labels;
757  labels->truncate(0);
758  if (!simple_text) labels->push_back(null_char);
759  if (unicharset.encode_string(str.string(), true, &internal_labels, NULL,
760  &err_index)) {
761  bool success = true;
762  for (int i = 0; i < internal_labels.size(); ++i) {
763  if (recoder != NULL) {
764  // Re-encode labels via recoder.
765  RecodedCharID code;
766  int len = recoder->EncodeUnichar(internal_labels[i], &code);
767  if (len > 0) {
768  for (int j = 0; j < len; ++j) {
769  labels->push_back(code(j));
770  if (!simple_text) labels->push_back(null_char);
771  }
772  } else {
773  success = false;
774  err_index = 0;
775  break;
776  }
777  } else {
778  labels->push_back(internal_labels[i]);
779  if (!simple_text) labels->push_back(null_char);
780  }
781  }
782  if (success) return true;
783  }
784  tprintf("Encoding of string failed! Failure bytes:");
785  while (err_index < str.length()) {
786  tprintf(" %x", str[err_index++]);
787  }
788  tprintf("\n");
789  return false;
790 }
791 
792 // Performs forward-backward on the given trainingdata.
793 // Returns a Trainability enum to indicate the suitability of the sample.
// NOTE(review): TrainOnLine fragment — original lines 794 (signature start),
// 799, 804 (debug condition tail), 809-810 (tail of the training-iteration
// condition), 817 and 821 are missing from this extracted view.
795  bool batch) {
796  NetworkIO fwd_outputs, targets;
// Forward pass + target preparation; bail out on unusable samples.
797  Trainability trainable =
798  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
800  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
801  return trainable; // Sample was unusable.
802  }
803  bool debug = debug_interval_ > 0 &&
805  // Run backprop on the output.
806  NetworkIO bp_deltas;
807  if (network_->IsTraining() &&
808  (trainable != PERFECT ||
811  network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
// batch mode passes momentum -1 so the update is held back, not applied.
812  network_->Update(learning_rate_, batch ? -1.0f : momentum_,
813  training_iteration_ + 1);
814  }
815 #ifndef GRAPHICS_DISABLED
816  if (debug_interval_ == 1 && debug_win_ != NULL) {
818  }
819 #endif // GRAPHICS_DISABLED
820  // Roll the memory of past means.
822  return trainable;
823 }
824 
825 // Prepares the ground truth, runs forward, and prepares the targets.
826 // Returns a Trainability enum to indicate the suitability of the sample.
828  NetworkIO* fwd_outputs,
829  NetworkIO* targets) {
830  if (trainingdata == NULL) {
831  tprintf("Null trainingdata.\n");
832  return UNENCODABLE;
833  }
834  // Ensure repeatability of random elements even across checkpoints.
835  bool debug = debug_interval_ > 0 &&
837  GenericVector<int> truth_labels;
838  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
839  tprintf("Can't encode transcription: %s\n",
840  trainingdata->transcription().string());
841  return UNENCODABLE;
842  }
843  int w = 0;
844  while (w < truth_labels.size() &&
845  (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
846  ++w;
847  if (w == truth_labels.size()) {
848  tprintf("Blank transcription: %s\n",
849  trainingdata->transcription().string());
850  return UNENCODABLE;
851  }
852  float image_scale;
853  NetworkIO inputs;
854  bool invert = trainingdata->boxes().empty();
855  if (!RecognizeLine(*trainingdata, invert, debug, invert, 0.0f, &image_scale,
856  &inputs, fwd_outputs)) {
857  tprintf("Image not trainable\n");
858  return UNENCODABLE;
859  }
860  targets->Resize(*fwd_outputs, network_->NumOutputs());
861  LossType loss_type = OutputLossType();
862  if (loss_type == LT_SOFTMAX) {
863  if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
864  tprintf("Compute simple targets failed!\n");
865  return UNENCODABLE;
866  }
867  } else if (loss_type == LT_CTC) {
868  if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
869  tprintf("Compute CTC targets failed!\n");
870  return UNENCODABLE;
871  }
872  } else {
873  tprintf("Logistic outputs not implemented yet!\n");
874  return UNENCODABLE;
875  }
876  GenericVector<int> ocr_labels;
877  GenericVector<int> xcoords;
878  LabelsFromOutputs(*fwd_outputs, 0.0f, &ocr_labels, &xcoords);
879  // CTC does not produce correct target labels to begin with.
880  if (loss_type != LT_CTC) {
881  LabelsFromOutputs(*targets, 0.0f, &truth_labels, &xcoords);
882  }
883  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
884  *targets)) {
885  tprintf("Input width was %d\n", inputs.Width());
886  return UNENCODABLE;
887  }
888  STRING ocr_text = DecodeLabels(ocr_labels);
889  STRING truth_text = DecodeLabels(truth_labels);
890  targets->SubtractAllFromFloat(*fwd_outputs);
891  if (debug_interval_ != 0) {
892  tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
893  ocr_text.string());
894  }
895  double char_error = ComputeCharError(truth_labels, ocr_labels);
896  double word_error = ComputeWordError(&truth_text, &ocr_text);
897  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
898  if (debug_interval_ != 0) {
899  tprintf("File %s page %d %s:\n", trainingdata->imagefilename().string(),
900  trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
901  }
902  if (delta_error == 0.0) return PERFECT;
903  if (targets->AnySuspiciousTruth(kHighConfidence)) return HI_PRECISION_ERR;
904  return TRAINABLE;
905 }
906 
907 // Writes the trainer to memory, so that the current training state can be
908 // restored.
910  const LSTMTrainer* trainer,
911  GenericVector<char>* data) const {
912  TFile fp;
913  fp.OpenWrite(data);
914  trainer->serialize_amount_ = serialize_amount;
915  return trainer->Serialize(&fp);
916 }
917 
918 // Reads previously saved trainer from memory.
920  LSTMTrainer* trainer) {
921  return trainer->ReadSizedTrainingDump(&data[0], data.size());
922 }
923 
924 bool LSTMTrainer::ReadSizedTrainingDump(const char* data, int size) {
925  TFile fp;
926  fp.Open(data, size);
927  return DeSerialize(&fp);
928 }
929 
930 // Writes the recognizer to memory, so that it can be used for testing later.
932  TFile fp;
933  fp.OpenWrite(data);
937 }
938 
939 // Reads and returns a previously saved recognizer from memory.
941  const GenericVector<char>& data) {
942  TFile fp;
943  fp.Open(&data[0], data.size());
944  LSTMRecognizer* recognizer = new LSTMRecognizer;
945  ASSERT_HOST(recognizer->DeSerialize(&fp));
946  return recognizer;
947 }
948 
949 // Returns a suitable filename for a training dump, based on the model_base_,
950 // the iteration and the error rates.
954  filename.add_str_int("_", best_iteration_);
955  filename += ".lstm";
956  return filename;
957 }
958 
959 // Fills the whole error buffer of the given type with the given value.
960 void LSTMTrainer::FillErrorBuffer(double new_error, ErrorTypes type) {
961  for (int i = 0; i < kRollingBufferSize_; ++i)
962  error_buffers_[type][i] = new_error;
963  error_rates_[type] = 100.0 * new_error;
964 }
965 
966 // Factored sub-constructor sets up reasonable default values.
968  align_win_ = NULL;
969  target_win_ = NULL;
970  ctc_win_ = NULL;
971  recon_win_ = NULL;
974  training_stage_ = 0;
976  InitIterations();
977 }
978 
979 // Sets the unicharset properties using the given script_dir as a source of
980 // script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
981 // up the recoder_ to simplify the unicharset.
983  tprintf("Setting unichar properties\n");
984  for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) {
985  if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0)
986  continue;
987  // Load the unicharset for the script if available.
988  STRING filename = script_dir + "/" +
990  ".unicharset";
991  UNICHARSET script_set;
992  GenericVector<char> data;
993  if ((*file_reader_)(filename, &data) &&
994  script_set.load_from_inmemory_file(&data[0], data.size())) {
995  tprintf("Setting properties for script %s\n",
996  GetUnicharset().get_script_from_script_id(s));
998  }
999  }
1000  if (IsRecoding()) {
1001  STRING filename = script_dir + "/radical-stroke.txt";
1002  GenericVector<char> data;
1003  if ((*file_reader_)(filename, &data)) {
1004  data += '\0';
1005  STRING stroke_table = &data[0];
1007  &stroke_table)) {
1008  RecodedCharID code;
1010  null_char_ = code(0);
1011  // Space should encode as itself.
1013  ASSERT_HOST(code(0) == UNICHAR_SPACE);
1014  return;
1015  }
1016  } else {
1017  tprintf("Failed to load radical-stroke info from: %s\n",
1018  filename.string());
1019  }
1021  }
1022 }
1023 
1024 // Outputs the string and periodically displays the given network inputs
1025 // as an image in the given window, and the corresponding labels at the
1026 // corresponding x_starts.
1027 // Returns false if the truth string is empty.
1029  const ImageData& trainingdata,
1030  const NetworkIO& fwd_outputs,
1031  const GenericVector<int>& truth_labels,
1032  const NetworkIO& outputs) {
1033  const STRING& truth_text = DecodeLabels(truth_labels);
1034  if (truth_text.string() == NULL || truth_text.length() <= 0) {
1035  tprintf("Empty truth string at decode time!\n");
1036  return false;
1037  }
1038  if (debug_interval_ != 0) {
1039  // Get class labels, xcoords and string.
1040  GenericVector<int> labels;
1041  GenericVector<int> xcoords;
1042  LabelsFromOutputs(outputs, 0.0f, &labels, &xcoords);
1043  STRING text = DecodeLabels(labels);
1044  tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
1045  training_iteration(), text.string());
1046  if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1047  tprintf("TRAINING activation path for truth string %s\n",
1048  truth_text.string());
1049  DebugActivationPath(outputs, labels, xcoords);
1050  DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1051  if (OutputLossType() == LT_CTC) {
1052  DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1053  DisplayTargets(outputs, "CTC Targets", &target_win_);
1054  }
1055  }
1056  }
1057  return true;
1058 }
1059 
1060 // Displays the network targets as line a line graph.
1062  const char* window_name, ScrollView** window) {
1063 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1064  int width = targets.Width();
1065  int num_features = targets.NumFeatures();
1066  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1067  window);
1068  for (int c = 0; c < num_features; ++c) {
1069  int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1070  (*window)->Pen(static_cast<ScrollView::Color>(color));
1071  int start_t = -1;
1072  for (int t = 0; t < width; ++t) {
1073  double target = targets.f(t)[c];
1074  target *= kTargetYScale;
1075  if (target >= 1) {
1076  if (start_t < 0) {
1077  (*window)->SetCursor(t - 1, 0);
1078  start_t = t;
1079  }
1080  (*window)->DrawTo(t, target);
1081  } else if (start_t >= 0) {
1082  (*window)->DrawTo(t, 0);
1083  (*window)->DrawTo(start_t - 1, 0);
1084  start_t = -1;
1085  }
1086  }
1087  if (start_t >= 0) {
1088  (*window)->DrawTo(width, 0);
1089  (*window)->DrawTo(start_t - 1, 0);
1090  }
1091  }
1092  (*window)->Update();
1093 #endif // GRAPHICS_DISABLED
1094 }
1095 
1096 // Builds a no-compromises target where the first positions should be the
1097 // truth labels and the rest is padded with the null_char_.
1099  const GenericVector<int>& truth_labels,
1100  NetworkIO* targets) {
1101  if (truth_labels.size() > targets->Width()) {
1102  tprintf("Error: transcription %s too long to fit into target of width %d\n",
1103  DecodeLabels(truth_labels).string(), targets->Width());
1104  return false;
1105  }
1106  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
1107  targets->SetActivations(i, truth_labels[i], 1.0);
1108  }
1109  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
1110  targets->SetActivations(i, null_char_, 1.0);
1111  }
1112  return true;
1113 }
1114 
1115 // Builds a target using standard CTC. truth_labels should be pre-padded with
1116 // nulls wherever desired. They don't have to be between all labels.
1117 // outputs is input-output, as it gets clipped to minimum probability.
1119  NetworkIO* outputs, NetworkIO* targets) {
1120  // Bottom-clip outputs to a minimum probability.
1121  CTC::NormalizeProbs(outputs);
1122  return CTC::ComputeCTCTargets(truth_labels, null_char_,
1123  outputs->float_array(), targets);
1124 }
1125 
1126 // Computes network errors, and stores the results in the rolling buffers,
1127 // along with the supplied text_error.
1128 // Returns the delta error of the current sample (not running average.)
1130  double char_error, double word_error) {
1132  // Delta error is the fraction of timesteps with >0.5 error in the top choice
1133  // score. If zero, then the top choice characters are guaranteed correct,
1134  // even when there is residue in the RMS error.
1135  double delta_error = ComputeWinnerError(deltas);
1136  UpdateErrorBuffer(delta_error, ET_DELTA);
1137  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1138  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1139  // Skip ratio measures the difference between sample_iteration_ and
1140  // training_iteration_, which reflects the number of unusable samples,
1141  // usually due to unencodable truth text, or the text not fitting in the
1142  // space for the output.
1143  double skip_count = sample_iteration_ - prev_sample_iteration_;
1144  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1145  return delta_error;
1146 }
1147 
1148 // Computes the network activation RMS error rate.
1150  double total_error = 0.0;
1151  int width = deltas.Width();
1152  int num_classes = deltas.NumFeatures();
1153  for (int t = 0; t < width; ++t) {
1154  const float* class_errs = deltas.f(t);
1155  for (int c = 0; c < num_classes; ++c) {
1156  double error = class_errs[c];
1157  total_error += error * error;
1158  }
1159  }
1160  return sqrt(total_error / (width * num_classes));
1161 }
1162 
1163 // Computes network activation winner error rate. (Number of values that are
1164 // in error by >= 0.5 divided by number of time-steps.) More closely related
1165 // to final character error than RMS, but still directly calculable from
1166 // just the deltas. Because of the binary nature of the targets, zero winner
1167 // error is a sufficient but not necessary condition for zero char error.
1169  int num_errors = 0;
1170  int width = deltas.Width();
1171  int num_classes = deltas.NumFeatures();
1172  for (int t = 0; t < width; ++t) {
1173  const float* class_errs = deltas.f(t);
1174  for (int c = 0; c < num_classes; ++c) {
1175  float abs_delta = fabs(class_errs[c]);
1176  // TODO(rays) Filtering cases where the delta is very large to cut out
1177  // GT errors doesn't work. Find a better way or get better truth.
1178  if (0.5 <= abs_delta)
1179  ++num_errors;
1180  }
1181  }
1182  return static_cast<double>(num_errors) / width;
1183 }
1184 
1185 // Computes a very simple bag of chars char error rate.
1187  const GenericVector<int>& ocr_str) {
1188  GenericVector<int> label_counts;
1189  label_counts.init_to_size(NumOutputs(), 0);
1190  int truth_size = 0;
1191  for (int i = 0; i < truth_str.size(); ++i) {
1192  if (truth_str[i] != null_char_) {
1193  ++label_counts[truth_str[i]];
1194  ++truth_size;
1195  }
1196  }
1197  for (int i = 0; i < ocr_str.size(); ++i) {
1198  if (ocr_str[i] != null_char_) {
1199  --label_counts[ocr_str[i]];
1200  }
1201  }
1202  int char_errors = 0;
1203  for (int i = 0; i < label_counts.size(); ++i) {
1204  char_errors += abs(label_counts[i]);
1205  }
1206  if (truth_size == 0) {
1207  return (char_errors == 0) ? 0.0 : 1.0;
1208  }
1209  return static_cast<double>(char_errors) / truth_size;
1210 }
1211 
1212 // Computes word recall error rate using a very simple bag of words algorithm.
1213 // NOTE that this is destructive on both input strings.
1214 double LSTMTrainer::ComputeWordError(STRING* truth_str, STRING* ocr_str) {
1215  typedef std::unordered_map<std::string, int, std::hash<std::string> > StrMap;
1216  GenericVector<STRING> truth_words, ocr_words;
1217  truth_str->split(' ', &truth_words);
1218  if (truth_words.empty()) return 0.0;
1219  ocr_str->split(' ', &ocr_words);
1220  StrMap word_counts;
1221  for (int i = 0; i < truth_words.size(); ++i) {
1222  std::string truth_word(truth_words[i].string());
1223  StrMap::iterator it = word_counts.find(truth_word);
1224  if (it == word_counts.end())
1225  word_counts.insert(std::make_pair(truth_word, 1));
1226  else
1227  ++it->second;
1228  }
1229  for (int i = 0; i < ocr_words.size(); ++i) {
1230  std::string ocr_word(ocr_words[i].string());
1231  StrMap::iterator it = word_counts.find(ocr_word);
1232  if (it == word_counts.end())
1233  word_counts.insert(std::make_pair(ocr_word, -1));
1234  else
1235  --it->second;
1236  }
1237  int word_recall_errs = 0;
1238  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1239  ++it) {
1240  if (it->second > 0) word_recall_errs += it->second;
1241  }
1242  return static_cast<double>(word_recall_errs) / truth_words.size();
1243 }
1244 
1245 // Updates the error buffer and corresponding mean of the given type with
1246 // the new_error.
1247 void LSTMTrainer::UpdateErrorBuffer(double new_error, ErrorTypes type) {
1249  error_buffers_[type][index] = new_error;
1250  // Compute the mean error.
1251  int mean_count = MIN(training_iteration_ + 1, error_buffers_[type].size());
1252  double buffer_sum = 0.0;
1253  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
1254  double mean = buffer_sum / mean_count;
1255  // Trim precision to 1/1000 of 1%.
1256  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1257 }
1258 
1259 // Rolls error buffers and reports the current means.
1262  if (NewSingleError(ET_DELTA) > 0.0)
1264  else
1267  if (debug_interval_ != 0) {
1268  tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1272  }
1273 }
1274 
1275 // Given that error_rate is either a new min or max, updates the best/worst
1276 // error rates, and record of progress.
1277 // Tester is an externally supplied callback function that tests on some
1278 // data set with a given model and records the error rates in a graph.
1279 STRING LSTMTrainer::UpdateErrorGraph(int iteration, double error_rate,
1280  const GenericVector<char>& model_data,
1281  TestCallback tester) {
1282  if (error_rate > best_error_rate_
1283  && iteration < best_iteration_ + kErrorGraphInterval) {
1284  // Too soon to record a new point.
1285  if (tester != NULL)
1286  return tester->Run(worst_iteration_, NULL, worst_model_data_,
1288  else
1289  return "";
1290  }
1291  STRING result;
1292  // NOTE: there are 2 asymmetries here:
1293  // 1. We are computing the global minimum, but the local maximum in between.
1294  // 2. If the tester returns an empty string, indicating that it is busy,
1295  // call it repeatedly on new local maxima to test the previous min, but
1296  // not the other way around, as there is little point testing the maxima
1297  // between very frequent minima.
1298  if (error_rate < best_error_rate_) {
1299  // This is a new (global) minimum.
1300  if (tester != NULL) {
1301  result = tester->Run(worst_iteration_, worst_error_rates_,
1304  best_model_data_ = model_data;
1305  }
1306  best_error_rate_ = error_rate;
1307  memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1308  best_iteration_ = iteration;
1309  best_error_history_.push_back(error_rate);
1310  best_error_iterations_.push_back(iteration);
1311  // Compute 2% decay time.
1312  double two_percent_more = error_rate + 2.0;
1313  int i;
1314  for (i = best_error_history_.size() - 1;
1315  i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1316  }
1317  int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1318  improvement_steps_ = iteration - old_iteration;
1319  tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1320  improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1321  old_iteration);
1322  } else if (error_rate > best_error_rate_) {
1323  // This is a new (local) maximum.
1324  if (tester != NULL) {
1325  if (best_model_data_.empty()) {
1326  // Allow for multiple data points with "worst" error rate.
1327  result = tester->Run(worst_iteration_, worst_error_rates_,
1329  } else {
1330  result = tester->Run(best_iteration_, best_error_rates_,
1332  }
1333  if (result.length() > 0)
1335  worst_model_data_ = model_data;
1336  }
1337  }
1338  worst_error_rate_ = error_rate;
1339  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1340  worst_iteration_ = iteration;
1341  return result;
1342 }
1343 
1344 } // namespace tesseract.
bool DeSerialize(bool swap, FILE *fp)
void PrepareLogMsg(STRING *log_msg) const
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
float GetLayerLearningRate(const STRING &id) const
int CurrentTrainingStage() const
Definition: lstmtrainer.h:213
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441
int num_weights() const
Definition: network.h:119
ScrollView * ctc_win_
Definition: lstmtrainer.h:394
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
const int kTargetYScale
Definition: lstmtrainer.cpp:71
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
bool AnySuspiciousTruth(float confidence_thr) const
Definition: networkio.cpp:579
int Width() const
Definition: networkio.h:107
int64_t inT64
Definition: host.h:40
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
float * f(int t)
Definition: networkio.h:115
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
Definition: tesscallback.h:116
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET *target_unicharset, const TBOX &line_box, float score_ratio, bool one_word, PointerVector< WERD_RES > *words)
virtual void Update(float learning_rate, float momentum, int num_samples)
Definition: network.h:218
void init_to_size(int size, T t)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
void InitCharSet(const UNICHARSET &unicharset, const STRING &script_dir, int train_flags)
virtual bool Serialize(TFile *fp) const
static const float kMinCertainty
Definition: recodebeam.h:213
LossType OutputLossType() const
double ComputeWinnerError(const NetworkIO &deltas)
int page_number() const
Definition: imagedata.h:130
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
SerializeAmount serialize_amount_
Definition: lstmtrainer.h:408
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:451
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:45
voidpf void uLong size
Definition: ioapi.h:39
void SubtractAllFromFloat(const NetworkIO &src)
Definition: networkio.cpp:824
virtual R Run(A1, A2)=0
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum)
void SetActivations(int t, int label, float ok_score)
Definition: networkio.cpp:537
const char * get_script_from_script_id(int id) const
Definition: unicharset.h:814
int push_back(T object)
const double kHighConfidence
Definition: lstmtrainer.cpp:64
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:54
const UNICHARSET & GetUnicharset() const
#define tprintf(...)
Definition: tprintf.h:31
void SetPropertiesFromOther(const UNICHARSET &src)
Definition: unicharset.h:505
NetworkScratch scratch_space_
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472
const char * string() const
Definition: strngs.cpp:198
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
ScrollView * target_win_
Definition: lstmtrainer.h:392
void ScaleLearningRate(double factor)
bool IsTraining() const
Definition: network.h:115
int FReadEndian(void *buffer, int size, int count)
Definition: serialis.cpp:97
bool empty() const
Definition: genericvector.h:90
void truncate(int size)
void LogIterations(const char *intro_str, STRING *log_msg) const
bool TryLoadingCheckpoint(const char *filename)
GenericVector< STRING > EnumerateLayers() const
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
Definition: recodebeam.cpp:100
const double kImprovementFraction
Definition: lstmtrainer.cpp:66
inT32 length() const
Definition: strngs.cpp:193
int IntCastRounded(double x)
Definition: helpers.h:179
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
int size() const
Definition: genericvector.h:72
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
DocumentCache training_data_
Definition: lstmtrainer.h:406
const STRING & transcription() const
Definition: imagedata.h:145
void OpenWrite(GenericVector< char > *data)
Definition: serialis.cpp:125
#define ASSERT_HOST(x)
Definition: errcode.h:84
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:62
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:406
const double kLearningRateDecay
Definition: lstmtrainer.cpp:52
bool ComputeEncoding(const UNICHARSET &unicharset, int null_id, STRING *radical_stroke_table)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
CachingStrategy CacheStrategy() const
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
int get_script_table_size() const
Definition: unicharset.h:809
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:446
ScrollView * recon_win_
Definition: lstmtrainer.h:396
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
virtual void DebugWeights()
Definition: network.h:204
void StartSubtrainer(STRING *log_msg)
Definition: strngs.h:45
bool has_special_codes() const
Definition: unicharset.h:682
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
void LabelsFromOutputs(const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
Network * GetLayer(const STRING &id) const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:268
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: network.h:259
UNICHARSET unicharset
Definition: ccutil.h:68
int FWrite(const void *buffer, int size, int count)
Definition: serialis.cpp:148
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:572
int learning_iteration() const
Definition: lstmtrainer.h:149
int InitTensorFlowNetwork(const std::string &tf_proto)
void SetUnicharsetProperties(const STRING &script_dir)
STRING DumpFilename() const
void FillErrorBuffer(double new_error, ErrorTypes type)
double learning_rate() const
virtual bool DeSerialize(TFile *fp)
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:438
double ComputeRMSError(const NetworkIO &deltas)
bool LoadAllTrainingData(const GenericVector< STRING > &filenames)
virtual R Run(A1, A2, A3)=0
bool ReadSizedTrainingDump(const char *data, int size)
const GenericVector< TBOX > & boxes() const
Definition: imagedata.h:148
const STRING & imagefilename() const
Definition: imagedata.h:124
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset)
Definition: recodebeam.cpp:76
static LSTMRecognizer * ReadRecognitionDump(const GenericVector< char > &data)
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:58
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
Definition: unicharset.cpp:234
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
#define MIN(x, y)
Definition: ndminx.h:28
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:426
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:439
const char * filename
Definition: ioapi.h:38
int NumFeatures() const
Definition: networkio.h:111
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:452
void SetIteration(int iteration)
uint8_t uinT8
Definition: host.h:35
const GENERIC_2D_ARRAY< float > & float_array() const
Definition: networkio.h:139
bool Open(const STRING &filename, FileReader reader)
Definition: serialis.cpp:38
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool Serialize(FILE *fp) const
const STRING & name() const
Definition: network.h:138
int size() const
Definition: unicharset.h:299
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
double CharError() const
Definition: lstmtrainer.h:139
virtual R Run(A1, A2, A3, A4)=0
const int kTargetXScale
Definition: lstmtrainer.cpp:70
void UpdateErrorBuffer(double new_error, ErrorTypes type)
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:56
virtual STRING spec() const
Definition: network.h:141
bool TransitionTrainingStage(float error_threshold)
STRING DecodeLabels(const GenericVector< int > &labels)
int NumOutputs() const
Definition: network.h:123
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:247
virtual void CountAlternators(const Network &other, double *same, double *changed) const
Definition: network.h:222
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53
bool load_from_inmemory_file(const char *const memory, int mem_size, bool skip_fragments)
Definition: unicharset.cpp:724
bool Serialize(TFile *fp) const
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:68
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112
ScrollView * align_win_
Definition: lstmtrainer.h:390
void split(const char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:286
void CopyFrom(const UNICHARSET &src)
Definition: unicharset.cpp:423
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:449
int FRead(void *buffer, int size, int count)
Definition: serialis.cpp:108
void SaveRecognitionDump(GenericVector< char > *data) const
void ScaleLayerLearningRate(const STRING &id, double factor)
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)