tesseract  4.00.00dev
lstmtrainer.h
Go to the documentation of this file.
1 // File: lstmtrainer.h
3 // Description: Top-level line trainer class for LSTM-based networks.
4 // Author: Ray Smith
5 // Created: Fri May 03 09:07:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_
20 #define TESSERACT_LSTM_LSTMTRAINER_H_
21 
22 #include "imagedata.h"
23 #include "lstmrecognizer.h"
24 #include "rect.h"
25 #include "tesscallback.h"
26 
27 namespace tesseract {
28 
29 class LSTM;
30 class LSTMTrainer;
31 class Parallel;
32 class Reversed;
33 class Softmax;
34 class Series;
35 
// Enum for the types of errors that are counted.
enum ErrorTypes {
  ET_RMS,          // RMS activation error.
  ET_DELTA,        // Number of big errors in deltas.
  ET_WORD_RECERR,  // Output text string word recall error.
  ET_CHAR_ERROR,   // Output text string total char error.
  ET_SKIP_RATIO,   // Fraction of samples skipped.
  ET_COUNT         // For array sizing.
};
45 
// Enum for the trainability_ flags.
enum Trainability {
  TRAINABLE,         // Non-zero delta error.
  PERFECT,           // Zero delta error.
  UNENCODABLE,       // Not trainable due to coding/alignment trouble.
  HI_PRECISION_ERR,  // Hi confidence disagreement.
  NOT_BOXED,         // Early in training and has no character boxes.
};
54 
// Enum to define the amount of data to get serialized.
enum SerializeAmount {
  LIGHT,            // Minimal data for remote training.
  NO_BEST_TRAINER,  // Save an empty vector in place of best_trainer_.
  FULL,             // All data including best_trainer_.
};
61 
// Enum to indicate how the sub_trainer_ training went.
enum SubTrainerResult {
  STR_NONE,     // Did nothing as not good enough.
  STR_UPDATED,  // Subtrainer was updated, but didn't replace *this.
  STR_REPLACED  // Subtrainer replaced *this.
};
68 
70 // Function to restore the trainer state from a given checkpoint.
71 // Returns false on failure.
74 // Function to save a checkpoint of the current trainer state.
75 // Returns false on failure. SerializeAmount determines the amount of the
76 // trainer to serialize, typically used for saving the best state.
77 typedef TessResultCallback3<bool, SerializeAmount, const LSTMTrainer*,
79 // Function to compute and record error rates on some external test set(s).
80 // Args are: iteration, mean errors, model, training stage.
81 // Returns a STRING containing logging information about the tests.
82 typedef TessResultCallback4<STRING, int, const double*,
84 
85 // Trainer class for LSTM networks. Most of the effort is in creating the
86 // ideal target outputs from the transcription. A box file is used if it is
87 // available, otherwise estimates of the char widths from the unicharset are
88 // used to guide a DP search for the best fit to the transcription.
89 class LSTMTrainer : public LSTMRecognizer {
90  public:
91  LSTMTrainer();
92  // Callbacks may be null, in which case defaults are used.
93  LSTMTrainer(FileReader file_reader, FileWriter file_writer,
94  CheckPointReader checkpoint_reader,
95  CheckPointWriter checkpoint_writer,
96  const char* model_base, const char* checkpoint_name,
97  int debug_interval, inT64 max_memory);
98  virtual ~LSTMTrainer();
99 
100  // Tries to deserialize a trainer from the given file and silently returns
101  // false in case of failure.
102  bool TryLoadingCheckpoint(const char* filename);
103 
104  // Initializes the character set encode/decode mechanism.
105  // train_flags control training behavior according to the TrainingFlags
106  // enum, including character set encoding.
107  // script_dir is required for TF_COMPRESS_UNICHARSET, and, if provided,
108  // fully initializes the unicharset from the universal unicharsets.
109  // Note: Call before InitNetwork!
110  void InitCharSet(const UNICHARSET& unicharset, const STRING& script_dir,
111  int train_flags);
112  // Initializes the character set encode/decode mechanism directly from a
113  // previously setup UNICHARSET and UnicharCompress.
114  // ctc_mode controls how the truth text is mapped to the network targets.
115  // Note: Call before InitNetwork!
116  void InitCharSet(const UNICHARSET& unicharset,
117  const UnicharCompress& recoder);
118 
119  // Initializes the trainer with a network_spec in the network description
120  // net_flags control network behavior according to the NetworkFlags enum.
121  // There isn't really much difference between them - only where the effects
122  // are implemented.
123  // For other args see NetworkBuilder::InitNetwork.
124  // Note: Be sure to call InitCharSet before InitNetwork!
125  bool InitNetwork(const STRING& network_spec, int append_index, int net_flags,
126  float weight_range, float learning_rate, float momentum);
127  // Initializes a trainer from a serialized TFNetworkModel proto.
128  // Returns the global step of TensorFlow graph or 0 if failed.
129  // Building a compatible TF graph: See tfnetwork.proto.
130  int InitTensorFlowNetwork(const std::string& tf_proto);
131  // Resets all the iteration counters for fine tuning or training a head,
132  // where we want the error reporting to reset.
133  void InitIterations();
134 
135  // Accessors.
136  double ActivationError() const {
137  return error_rates_[ET_DELTA];
138  }
139  double CharError() const { return error_rates_[ET_CHAR_ERROR]; }
140  const double* error_rates() const {
141  return error_rates_;
142  }
143  double best_error_rate() const {
144  return best_error_rate_;
145  }
146  int best_iteration() const {
147  return best_iteration_;
148  }
149  int learning_iteration() const { return learning_iteration_; }
150  int improvement_steps() const { return improvement_steps_; }
151  void set_perfect_delay(int delay) { perfect_delay_ = delay; }
152  const GenericVector<char>& best_trainer() const { return best_trainer_; }
153  // Returns the error that was just calculated by PrepareForBackward.
154  double NewSingleError(ErrorTypes type) const {
156  }
157  // Returns the error that was just calculated by TrainOnLine. Since
158  // TrainOnLine rolls the error buffers, this is one further back than
159  // NewSingleError.
160  double LastSingleError(ErrorTypes type) const {
161  return error_buffers_[type]
164  }
165  const DocumentCache& training_data() const {
166  return training_data_;
167  }
169 
170  // If the training sample is usable, grid searches for the optimal
171  // dict_ratio/cert_offset, and returns the results in a string of space-
172  // separated triplets of ratio,offset=worderr.
174  const ImageData* trainingdata, int iteration, double min_dict_ratio,
175  double dict_ratio_step, double max_dict_ratio, double min_cert_offset,
176  double cert_offset_step, double max_cert_offset, STRING* results);
177 
178  void SetSerializeMode(SerializeAmount serialize_amount) const {
179  serialize_amount_ = serialize_amount;
180  }
181 
182  // Provides output on the distribution of weight values.
183  void DebugNetwork();
184 
185  // Loads a set of lstmf files that were created using the lstm.train config to
186  // tesseract into memory ready for training. Returns false if nothing was
187  // loaded.
188  bool LoadAllTrainingData(const GenericVector<STRING>& filenames);
189 
190  // Keeps track of best and locally worst error rate, using internally computed
191  // values. See MaintainCheckpointsSpecific for more detail.
192  bool MaintainCheckpoints(TestCallback tester, STRING* log_msg);
193  // Keeps track of best and locally worst error_rate (whatever it is) and
194  // launches tests using rec_model, when a new min or max is reached.
195  // Writes checkpoints using train_model at appropriate times and builds and
196  // returns a log message to indicate progress. Returns false if nothing
197  // interesting happened.
198  bool MaintainCheckpointsSpecific(int iteration,
199  const GenericVector<char>* train_model,
200  const GenericVector<char>* rec_model,
201  TestCallback tester, STRING* log_msg);
202  // Builds a string containing a progress message with current error rates.
203  void PrepareLogMsg(STRING* log_msg) const;
204  // Appends <intro_str> iteration learning_iteration()/training_iteration()/
205  // sample_iteration() to the log_msg.
206  void LogIterations(const char* intro_str, STRING* log_msg) const;
207 
208  // TODO(rays) Add curriculum learning.
209  // Returns true and increments the training_stage_ if the error rate has just
210  // passed through the given threshold for the first time.
211  bool TransitionTrainingStage(float error_threshold);
212  // Returns the current training stage.
213  int CurrentTrainingStage() const { return training_stage_; }
214 
215  // Writes to the given file. Returns false in case of error.
216  virtual bool Serialize(TFile* fp) const;
217  // Reads from the given file. Returns false in case of error.
218  virtual bool DeSerialize(TFile* fp);
219 
220  // De-serializes the saved best_trainer_ into sub_trainer_, and adjusts the
221  // learning rates (by scaling reduction, or layer specific, according to
222  // NF_LAYER_SPECIFIC_LR).
223  void StartSubtrainer(STRING* log_msg);
224  // While the sub_trainer_ is behind the current training iteration and its
225  // training error is at least kSubTrainerMarginFraction better than the
226  // current training error, trains the sub_trainer_, and returns STR_UPDATED if
227  // it did anything. If it catches up, and has a better error rate than the
228  // current best, as well as a margin over the current error rate, then the
229  // trainer in *this is replaced with sub_trainer_, and STR_REPLACED is
230  // returned. STR_NONE is returned if the subtrainer wasn't good enough to
231  // receive any training iterations.
232  SubTrainerResult UpdateSubtrainer(STRING* log_msg);
233  // Reduces network learning rates, either for everything, or for layers
234  // independently, according to NF_LAYER_SPECIFIC_LR.
235  void ReduceLearningRates(LSTMTrainer* samples_trainer, STRING* log_msg);
236  // Considers reducing the learning rate independently for each layer down by
237  // factor(<1), or leaving it the same, by double-training the given number of
238  // samples and minimizing the amount of changing of sign of weight updates.
239  // Even if it looks like all weights should remain the same, an adjustment
240  // will be made to guarantee a different result when reverting to an old best.
241  // Returns the number of layer learning rates that were reduced.
242  int ReduceLayerLearningRates(double factor, int num_samples,
243  LSTMTrainer* samples_trainer);
244 
245  // Converts the string to integer class labels, with appropriate null_char_s
246  // in between if not in SimpleTextOutput mode. Returns false on failure.
247  bool EncodeString(const STRING& str, GenericVector<int>* labels) const {
248  return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : NULL,
249  SimpleTextOutput(), null_char_, labels);
250  }
251  // Static version operates on supplied unicharset, encoder, simple_text.
252  static bool EncodeString(const STRING& str, const UNICHARSET& unicharset,
253  const UnicharCompress* recoder, bool simple_text,
254  int null_char, GenericVector<int>* labels);
255 
256  // Converts the network to int if not already.
257  void ConvertToInt() {
258  if ((training_flags_ & TF_INT_MODE) == 0) {
261  }
262  }
263 
264  // Performs forward-backward on the given trainingdata.
265  // Returns the sample that was used or NULL if the next sample was deemed
266  // unusable. samples_trainer could be this or an alternative trainer that
267  // holds the training samples.
268  const ImageData* TrainOnLine(LSTMTrainer* samples_trainer, bool batch) {
269  int sample_index = sample_iteration();
270  const ImageData* image =
271  samples_trainer->training_data_.GetPageBySerial(sample_index);
272  if (image != NULL) {
273  Trainability trainable = TrainOnLine(image, batch);
274  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
275  return NULL; // Sample was unusable.
276  }
277  } else {
279  }
280  return image;
281  }
282  Trainability TrainOnLine(const ImageData* trainingdata, bool batch);
283 
284  // Prepares the ground truth, runs forward, and prepares the targets.
285  // Returns a Trainability enum to indicate the suitability of the sample.
286  Trainability PrepareForBackward(const ImageData* trainingdata,
287  NetworkIO* fwd_outputs, NetworkIO* targets);
288 
289  // Writes the trainer to memory, so that the current training state can be
290  // restored.
291  bool SaveTrainingDump(SerializeAmount serialize_amount,
292  const LSTMTrainer* trainer,
293  GenericVector<char>* data) const;
294 
295  // Reads previously saved trainer from memory.
296  bool ReadTrainingDump(const GenericVector<char>& data, LSTMTrainer* trainer);
297  bool ReadSizedTrainingDump(const char* data, int size);
298 
299  // Sets up the data for MaintainCheckpoints from a light ReadTrainingDump.
300  void SetupCheckpointInfo();
301 
302  // Writes the recognizer to memory, so that it can be used for testing later.
303  void SaveRecognitionDump(GenericVector<char>* data) const;
304 
305  // Reads and returns a previously saved recognizer from memory.
307 
308  // Writes current best model to a file, unless it has already been written.
309  bool SaveBestModel(FileWriter writer) const;
310 
311  // Returns a suitable filename for a training dump, based on the model_base_,
312  // the iteration and the error rates.
313  STRING DumpFilename() const;
314 
315  // Fills the whole error buffer of the given type with the given value.
316  void FillErrorBuffer(double new_error, ErrorTypes type);
317 
318  protected:
319  // Factored sub-constructor sets up reasonable default values.
320  void EmptyConstructor();
321 
322  // Sets the unicharset properties using the given script_dir as a source of
323  // script unicharsets. If the flag TF_COMPRESS_UNICHARSET is true, also sets
324  // up the recoder_ to simplify the unicharset.
325  void SetUnicharsetProperties(const STRING& script_dir);
326 
327  // Outputs the string and periodically displays the given network inputs
328  // as an image in the given window, and the corresponding labels at the
329  // corresponding x_starts.
330  // Returns false if the truth string is empty.
331  bool DebugLSTMTraining(const NetworkIO& inputs,
332  const ImageData& trainingdata,
333  const NetworkIO& fwd_outputs,
334  const GenericVector<int>& truth_labels,
335  const NetworkIO& outputs);
336  // Displays the network targets as line a line graph.
337  void DisplayTargets(const NetworkIO& targets, const char* window_name,
338  ScrollView** window);
339 
340  // Builds a no-compromises target where the first positions should be the
341  // truth labels and the rest is padded with the null_char_.
342  bool ComputeTextTargets(const NetworkIO& outputs,
343  const GenericVector<int>& truth_labels,
344  NetworkIO* targets);
345 
346  // Builds a target using standard CTC. truth_labels should be pre-padded with
347  // nulls wherever desired. They don't have to be between all labels.
348  // outputs is input-output, as it gets clipped to minimum probability.
349  bool ComputeCTCTargets(const GenericVector<int>& truth_labels,
350  NetworkIO* outputs, NetworkIO* targets);
351 
352  // Computes network errors, and stores the results in the rolling buffers,
353  // along with the supplied text_error.
354  // Returns the delta error of the current sample (not running average.)
355  double ComputeErrorRates(const NetworkIO& deltas, double char_error,
356  double word_error);
357 
358  // Computes the network activation RMS error rate.
359  double ComputeRMSError(const NetworkIO& deltas);
360 
361  // Computes network activation winner error rate. (Number of values that are
362  // in error by >= 0.5 divided by number of time-steps.) More closely related
363  // to final character error than RMS, but still directly calculable from
364  // just the deltas. Because of the binary nature of the targets, zero winner
365  // error is a sufficient but not necessary condition for zero char error.
366  double ComputeWinnerError(const NetworkIO& deltas);
367 
368  // Computes a very simple bag of chars char error rate.
369  double ComputeCharError(const GenericVector<int>& truth_str,
370  const GenericVector<int>& ocr_str);
371  // Computes a very simple bag of words word recall error rate.
372  // NOTE that this is destructive on both input strings.
373  double ComputeWordError(STRING* truth_str, STRING* ocr_str);
374 
375  // Updates the error buffer and corresponding mean of the given type with
376  // the new_error.
377  void UpdateErrorBuffer(double new_error, ErrorTypes type);
378 
379  // Rolls error buffers and reports the current means.
380  void RollErrorBuffers();
381 
382  // Given that error_rate is either a new min or max, updates the best/worst
383  // error rates, and record of progress.
384  STRING UpdateErrorGraph(int iteration, double error_rate,
385  const GenericVector<char>& model_data,
386  TestCallback tester);
387 
388  protected:
389  // Alignment display window.
391  // CTC target display window.
393  // CTC output display window.
395  // Reconstructed image window.
397  // How often to display a debug image.
399  // Iteration at which the last checkpoint was dumped.
401  // Basename of files to save best models to.
402  STRING model_base_;
403  // Checkpoint filename.
405  // Training data.
407  // A hack to serialize less data for batch training and record file version.
408  mutable SerializeAmount serialize_amount_;
409  // Name to use when saving best_trainer_.
411  // Number of available training stages.
413  // Checkpointing callbacks.
416  // TODO(rays) These are pointers, and must be deleted. Switch to unique_ptr
417  // when we can commit to c++11.
418  CheckPointReader checkpoint_reader_;
419  CheckPointWriter checkpoint_writer_;
420 
421  // ===Serialized data to ensure that a restart produces the same results.===
422  // These members are only serialized when serialize_amount_ != LIGHT.
423  // Best error rate so far.
425  // Snapshot of all error rates at best_iteration_.
427  // Iteration of best_error_rate_.
429  // Worst error rate since best_error_rate_.
431  // Snapshot of all error rates at worst_iteration_.
433  // Iteration of worst_error_rate_.
435  // Iteration at which the process will be thought stalled.
437  // Saved recognition models for computing test error for graph points.
440  // Saved trainer for reverting back to last known best.
442  // A subsidiary trainer running with a different learning rate until either
443  // *this or sub_trainer_ hits a new best.
444  LSTMTrainer* sub_trainer_;
445  // Error rate at which last best model was dumped.
447  // Current stage of training.
449  // History of best error rate against iteration. Used for computing the
450  // number of steps to each 2% improvement.
453  // Number of iterations since the best_error_rate_ was 2% more than it is now.
455  // Number of iterations that yielded a non-zero delta error and thus provided
456  // significant learning. learning_iteration_ <= training_iteration_.
457  // learning_iteration_ is used to measure rate of learning progress.
459  // Saved value of sample_iteration_ before looking for the the next sample.
461  // How often to include a PERFECT training sample in backprop.
462  // A PERFECT training sample is used if the current
463  // training_iteration_ > last_perfect_training_iteration_ + perfect_delay_,
464  // so with perfect_delay_ == 0, all samples are used, and with
465  // perfect_delay_ == 4, at most 1 in 5 samples will be perfect.
467  // Value of training_iteration_ at which the last PERFECT training sample
468  // was used in back prop.
470  // Rolling buffers storing recent training errors are indexed by
471  // training_iteration % kRollingBufferSize_.
472  static const int kRollingBufferSize_ = 1000;
474  // Rounded mean percent trailing training errors in the buffers.
475  double error_rates_[ET_COUNT]; // RMS training error.
476 };
477 
478 } // namespace tesseract.
479 
480 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
void PrepareLogMsg(STRING *log_msg) const
double best_error_rate() const
Definition: lstmtrainer.h:143
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
const ImageData * GetPageBySerial(int serial)
Definition: imagedata.h:335
int CurrentTrainingStage() const
Definition: lstmtrainer.h:213
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441
ScrollView * ctc_win_
Definition: lstmtrainer.h:394
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
int64_t inT64
Definition: host.h:40
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
void InitCharSet(const UNICHARSET &unicharset, const STRING &script_dir, int train_flags)
virtual bool Serialize(TFile *fp) const
double ComputeWinnerError(const NetworkIO &deltas)
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
SerializeAmount serialize_amount_
Definition: lstmtrainer.h:408
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:451
const double * error_rates() const
Definition: lstmtrainer.h:140
voidpf void uLong size
Definition: ioapi.h:39
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum)
DocumentCache * mutable_training_data()
Definition: lstmtrainer.h:168
const UNICHARSET & GetUnicharset() const
int best_iteration() const
Definition: lstmtrainer.h:146
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472
ScrollView * target_win_
Definition: lstmtrainer.h:392
void LogIterations(const char *intro_str, STRING *log_msg) const
bool TryLoadingCheckpoint(const char *filename)
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
DocumentCache training_data_
Definition: lstmtrainer.h:406
double ActivationError() const
Definition: lstmtrainer.h:136
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
void set_perfect_delay(int delay)
Definition: lstmtrainer.h:151
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:446
ScrollView * recon_win_
Definition: lstmtrainer.h:396
virtual void ConvertToInt()
Definition: network.h:177
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
void StartSubtrainer(STRING *log_msg)
Definition: strngs.h:45
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154
const DocumentCache & training_data() const
Definition: lstmtrainer.h:165
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
Definition: lstmtrainer.h:78
void SetSerializeMode(SerializeAmount serialize_amount) const
Definition: lstmtrainer.h:178
bool SaveBestModel(FileWriter writer) const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:268
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418
int learning_iteration() const
Definition: lstmtrainer.h:149
int InitTensorFlowNetwork(const std::string &tf_proto)
void SetUnicharsetProperties(const STRING &script_dir)
STRING DumpFilename() const
void FillErrorBuffer(double new_error, ErrorTypes type)
double learning_rate() const
virtual bool DeSerialize(TFile *fp)
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:438
double ComputeRMSError(const NetworkIO &deltas)
const GenericVector< char > & best_trainer() const
Definition: lstmtrainer.h:152
bool LoadAllTrainingData(const GenericVector< STRING > &filenames)
bool ReadSizedTrainingDump(const char *data, int size)
static LSTMRecognizer * ReadRecognitionDump(const GenericVector< char > &data)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
double LastSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:160
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:426
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:439
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
Definition: lstmtrainer.h:69
const char * filename
Definition: ioapi.h:38
typedef int(ZCALLBACK *close_file_func) OF((voidpf opaque
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:452
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
double CharError() const
Definition: lstmtrainer.h:139
void UpdateErrorBuffer(double new_error, ErrorTypes type)
bool TransitionTrainingStage(float error_threshold)
int improvement_steps() const
Definition: lstmtrainer.h:150
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:247
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
ScrollView * align_win_
Definition: lstmtrainer.h:390
TessResultCallback4< STRING, int, const double *, const GenericVector< char > &, int > * TestCallback
Definition: lstmtrainer.h:83
void SaveRecognitionDump(GenericVector< char > *data) const
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)