tesseract  4.00.00dev
tesseract::LSTMTrainer Class Reference

#include <lstmtrainer.h>

Inheritance diagram for tesseract::LSTMTrainer:
tesseract::LSTMRecognizer

Public Member Functions

 LSTMTrainer ()
 
 LSTMTrainer (FileReader file_reader, FileWriter file_writer, CheckPointReader checkpoint_reader, CheckPointWriter checkpoint_writer, const char *model_base, const char *checkpoint_name, int debug_interval, inT64 max_memory)
 
virtual ~LSTMTrainer ()
 
bool TryLoadingCheckpoint (const char *filename)
 
void InitCharSet (const UNICHARSET &unicharset, const STRING &script_dir, int train_flags)
 
void InitCharSet (const UNICHARSET &unicharset, const UnicharCompress &recoder)
 
bool InitNetwork (const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum)
 
int InitTensorFlowNetwork (const std::string &tf_proto)
 
void InitIterations ()
 
double ActivationError () const
 
double CharError () const
 
const double * error_rates () const
 
double best_error_rate () const
 
int best_iteration () const
 
int learning_iteration () const
 
int improvement_steps () const
 
void set_perfect_delay (int delay)
 
const GenericVector< char > & best_trainer () const
 
double NewSingleError (ErrorTypes type) const
 
double LastSingleError (ErrorTypes type) const
 
const DocumentCache & training_data () const
 
DocumentCache * mutable_training_data ()
 
Trainability GridSearchDictParams (const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)
 
void SetSerializeMode (SerializeAmount serialize_amount) const
 
void DebugNetwork ()
 
bool LoadAllTrainingData (const GenericVector< STRING > &filenames)
 
bool MaintainCheckpoints (TestCallback tester, STRING *log_msg)
 
bool MaintainCheckpointsSpecific (int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
 
void PrepareLogMsg (STRING *log_msg) const
 
void LogIterations (const char *intro_str, STRING *log_msg) const
 
bool TransitionTrainingStage (float error_threshold)
 
int CurrentTrainingStage () const
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
void StartSubtrainer (STRING *log_msg)
 
SubTrainerResult UpdateSubtrainer (STRING *log_msg)
 
void ReduceLearningRates (LSTMTrainer *samples_trainer, STRING *log_msg)
 
int ReduceLayerLearningRates (double factor, int num_samples, LSTMTrainer *samples_trainer)
 
bool EncodeString (const STRING &str, GenericVector< int > *labels) const
 
void ConvertToInt ()
 
const ImageData * TrainOnLine (LSTMTrainer *samples_trainer, bool batch)
 
Trainability TrainOnLine (const ImageData *trainingdata, bool batch)
 
Trainability PrepareForBackward (const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
 
bool SaveTrainingDump (SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
 
bool ReadTrainingDump (const GenericVector< char > &data, LSTMTrainer *trainer)
 
bool ReadSizedTrainingDump (const char *data, int size)
 
void SetupCheckpointInfo ()
 
void SaveRecognitionDump (GenericVector< char > *data) const
 
bool SaveBestModel (FileWriter writer) const
 
STRING DumpFilename () const
 
void FillErrorBuffer (double new_error, ErrorTypes type)
 
- Public Member Functions inherited from tesseract::LSTMRecognizer
 LSTMRecognizer ()
 
 ~LSTMRecognizer ()
 
int NumOutputs () const
 
int training_iteration () const
 
int sample_iteration () const
 
double learning_rate () const
 
bool IsHardening () const
 
LossType OutputLossType () const
 
bool SimpleTextOutput () const
 
bool IsIntMode () const
 
bool IsRecoding () const
 
CachingStrategy CacheStrategy () const
 
bool IsTensorFlow () const
 
GenericVector< STRING > EnumerateLayers () const
 
Network * GetLayer (const STRING &id) const
 
float GetLayerLearningRate (const STRING &id) const
 
void ScaleLearningRate (double factor)
 
void ScaleLayerLearningRate (const STRING &id, double factor)
 
bool IsUsingAdaGrad () const
 
const UNICHARSET & GetUnicharset () const
 
const Dict * GetDict () const
 
void SetIteration (int iteration)
 
int NumInputs () const
 
int null_char () const
 
bool Serialize (TFile *fp) const
 
bool DeSerialize (TFile *fp)
 
bool LoadDictionary (const char *lang, TessdataManager *mgr)
 
void RecognizeLine (const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET *target_unicharset, const TBOX &line_box, float score_ratio, bool one_word, PointerVector< WERD_RES > *words)
 
void WordsFromOutputs (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > label_coords, const TBOX &line_box, bool debug, bool use_alternates, bool one_word, float score_ratio, float scale_factor, const UNICHARSET *target_unicharset, PointerVector< WERD_RES > *words)
 
void OutputStats (const NetworkIO &outputs, float *min_output, float *mean_output, float *sd)
 
bool RecognizeLine (const ImageData &image_data, bool invert, bool debug, bool re_invert, float label_threshold, float *scale_factor, NetworkIO *inputs, NetworkIO *outputs)
 
WERD_RES * WordFromOutput (const TBOX &line_box, const NetworkIO &outputs, int word_start, int word_end, float score_ratio, float space_certainty, bool debug, bool use_alternates, const UNICHARSET *target_unicharset, const GenericVector< int > &labels, const GenericVector< int > &label_coords, float scale_factor)
 
WERD_RES * InitializeWord (const TBOX &line_box, int word_start, int word_end, float space_certainty, bool use_alternates, const UNICHARSET *target_unicharset, const GenericVector< int > &labels, const GenericVector< int > &label_coords, float scale_factor)
 
STRING DecodeLabels (const GenericVector< int > &labels)
 
void DisplayForward (const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
 

Static Public Member Functions

static bool EncodeString (const STRING &str, const UNICHARSET &unicharset, const UnicharCompress *recoder, bool simple_text, int null_char, GenericVector< int > *labels)
 
static LSTMRecognizer * ReadRecognitionDump (const GenericVector< char > &data)
 

Protected Member Functions

void EmptyConstructor ()
 
void SetUnicharsetProperties (const STRING &script_dir)
 
bool DebugLSTMTraining (const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
 
void DisplayTargets (const NetworkIO &targets, const char *window_name, ScrollView **window)
 
bool ComputeTextTargets (const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
 
bool ComputeCTCTargets (const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
 
double ComputeErrorRates (const NetworkIO &deltas, double char_error, double word_error)
 
double ComputeRMSError (const NetworkIO &deltas)
 
double ComputeWinnerError (const NetworkIO &deltas)
 
double ComputeCharError (const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
 
double ComputeWordError (STRING *truth_str, STRING *ocr_str)
 
void UpdateErrorBuffer (double new_error, ErrorTypes type)
 
void RollErrorBuffers ()
 
STRING UpdateErrorGraph (int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
 
- Protected Member Functions inherited from tesseract::LSTMRecognizer
void SetRandomSeed ()
 
void DisplayLSTMOutput (const GenericVector< int > &labels, const GenericVector< int > &xcoords, int height, ScrollView *window)
 
void DebugActivationPath (const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
 
void DebugActivationRange (const NetworkIO &outputs, const char *label, int best_choice, int x_start, int x_end)
 
void LabelsFromOutputs (const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaThreshold (const NetworkIO &output, float null_threshold, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaCTC (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaReEncode (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
void LabelsViaSimpleText (const NetworkIO &output, GenericVector< int > *labels, GenericVector< int > *xcoords)
 
BLOB_CHOICE_LIST * GetBlobChoices (int col, int row, bool debug, const NetworkIO &output, const UNICHARSET *target_unicharset, int x_start, int x_end, float score_ratio)
 
bool AddBlobChoices (int unichar_id, float rating, float certainty, int col, int row, const UNICHARSET *target_unicharset, BLOB_CHOICE_IT *bc_it)
 
const char * DecodeLabel (const GenericVector< int > &labels, int start, int *end, int *decoded)
 
const char * DecodeSingleLabel (int label)
 

Protected Attributes

ScrollView * align_win_
 
ScrollView * target_win_
 
ScrollView * ctc_win_
 
ScrollView * recon_win_
 
int debug_interval_
 
int checkpoint_iteration_
 
STRING model_base_
 
STRING checkpoint_name_
 
DocumentCache training_data_
 
SerializeAmount serialize_amount_
 
STRING best_model_name_
 
int num_training_stages_
 
FileReader file_reader_
 
FileWriter file_writer_
 
CheckPointReader checkpoint_reader_
 
CheckPointWriter checkpoint_writer_
 
double best_error_rate_
 
double best_error_rates_ [ET_COUNT]
 
int best_iteration_
 
double worst_error_rate_
 
double worst_error_rates_ [ET_COUNT]
 
int worst_iteration_
 
int stall_iteration_
 
GenericVector< char > best_model_data_
 
GenericVector< char > worst_model_data_
 
GenericVector< char > best_trainer_
 
LSTMTrainer * sub_trainer_
 
float error_rate_of_last_saved_best_
 
int training_stage_
 
GenericVector< double > best_error_history_
 
GenericVector< int > best_error_iterations_
 
int improvement_steps_
 
int learning_iteration_
 
int prev_sample_iteration_
 
int perfect_delay_
 
int last_perfect_training_iteration_
 
GenericVector< double > error_buffers_ [ET_COUNT]
 
double error_rates_ [ET_COUNT]
 
- Protected Attributes inherited from tesseract::LSTMRecognizer
Network * network_
 
CCUtil ccutil_
 
UnicharCompress recoder_
 
STRING network_str_
 
inT32 training_flags_
 
inT32 training_iteration_
 
inT32 sample_iteration_
 
inT32 null_char_
 
float weight_range_
 
float learning_rate_
 
float momentum_
 
TRand randomizer_
 
NetworkScratch scratch_space_
 
Dict * dict_
 
RecodeBeamSearch * search_
 
ScrollView * debug_win_
 

Static Protected Attributes

static const int kRollingBufferSize_ = 1000
 

Detailed Description

Definition at line 89 of file lstmtrainer.h.

Constructor & Destructor Documentation

◆ LSTMTrainer() [1/2]

tesseract::LSTMTrainer::LSTMTrainer ( )

Definition at line 73 of file lstmtrainer.cpp.

74  : training_data_(0),
81  sub_trainer_(NULL) {
83  debug_interval_ = 0;
84 }
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
Definition: tesscallback.h:116
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
DocumentCache training_data_
Definition: lstmtrainer.h:406
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418

◆ LSTMTrainer() [2/2]

tesseract::LSTMTrainer::LSTMTrainer ( FileReader  file_reader,
FileWriter  file_writer,
CheckPointReader  checkpoint_reader,
CheckPointWriter  checkpoint_writer,
const char *  model_base,
const char *  checkpoint_name,
int  debug_interval,
inT64  max_memory 
)

Definition at line 86 of file lstmtrainer.cpp.

91  : training_data_(max_memory),
92  file_reader_(file_reader),
93  file_writer_(file_writer),
94  checkpoint_reader_(checkpoint_reader),
95  checkpoint_writer_(checkpoint_writer),
96  sub_trainer_(NULL) {
100  if (checkpoint_reader_ == NULL) {
103  }
104  if (checkpoint_writer_ == NULL) {
107  }
108  debug_interval_ = debug_interval;
109  model_base_ = model_base;
110  checkpoint_name_ = checkpoint_name;
111 }
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
Definition: tesscallback.h:116
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
DocumentCache training_data_
Definition: lstmtrainer.h:406
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418

◆ ~LSTMTrainer()

tesseract::LSTMTrainer::~LSTMTrainer ( )
virtual

Definition at line 113 of file lstmtrainer.cpp.

113  {
114  delete align_win_;
115  delete target_win_;
116  delete ctc_win_;
117  delete recon_win_;
118  delete checkpoint_reader_;
119  delete checkpoint_writer_;
120  delete sub_trainer_;
121 }
ScrollView * ctc_win_
Definition: lstmtrainer.h:394
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
ScrollView * target_win_
Definition: lstmtrainer.h:392
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
ScrollView * recon_win_
Definition: lstmtrainer.h:396
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418
ScrollView * align_win_
Definition: lstmtrainer.h:390

Member Function Documentation

◆ ActivationError()

double tesseract::LSTMTrainer::ActivationError ( ) const
inline

Definition at line 136 of file lstmtrainer.h.

136  {
137  return error_rates_[ET_DELTA];
138  }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475

◆ best_error_rate()

double tesseract::LSTMTrainer::best_error_rate ( ) const
inline

Definition at line 143 of file lstmtrainer.h.

143  {
144  return best_error_rate_;
145  }

◆ best_iteration()

int tesseract::LSTMTrainer::best_iteration ( ) const
inline

Definition at line 146 of file lstmtrainer.h.

146  {
147  return best_iteration_;
148  }

◆ best_trainer()

const GenericVector<char>& tesseract::LSTMTrainer::best_trainer ( ) const
inline

Definition at line 152 of file lstmtrainer.h.

152 { return best_trainer_; }
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441

◆ CharError()

double tesseract::LSTMTrainer::CharError ( ) const
inline

Definition at line 139 of file lstmtrainer.h.

◆ ComputeCharError()

double tesseract::LSTMTrainer::ComputeCharError ( const GenericVector< int > &  truth_str,
const GenericVector< int > &  ocr_str 
)
protected

Definition at line 1186 of file lstmtrainer.cpp.

1187  {
1188  GenericVector<int> label_counts;
1189  label_counts.init_to_size(NumOutputs(), 0);
1190  int truth_size = 0;
1191  for (int i = 0; i < truth_str.size(); ++i) {
1192  if (truth_str[i] != null_char_) {
1193  ++label_counts[truth_str[i]];
1194  ++truth_size;
1195  }
1196  }
1197  for (int i = 0; i < ocr_str.size(); ++i) {
1198  if (ocr_str[i] != null_char_) {
1199  --label_counts[ocr_str[i]];
1200  }
1201  }
1202  int char_errors = 0;
1203  for (int i = 0; i < label_counts.size(); ++i) {
1204  char_errors += abs(label_counts[i]);
1205  }
1206  if (truth_size == 0) {
1207  return (char_errors == 0) ? 0.0 : 1.0;
1208  }
1209  return static_cast<double>(char_errors) / truth_size;
1210 }
void init_to_size(int size, T t)
int size() const
Definition: genericvector.h:72

◆ ComputeCTCTargets()

bool tesseract::LSTMTrainer::ComputeCTCTargets ( const GenericVector< int > &  truth_labels,
NetworkIO *  outputs,
NetworkIO *  targets 
)
protected

Definition at line 1118 of file lstmtrainer.cpp.

1119  {
1120  // Bottom-clip outputs to a minimum probability.
1121  CTC::NormalizeProbs(outputs);
1122  return CTC::ComputeCTCTargets(truth_labels, null_char_,
1123  outputs->float_array(), targets);
1124 }
static void NormalizeProbs(NetworkIO *probs)
Definition: ctc.h:36
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
Definition: ctc.cpp:53

◆ ComputeErrorRates()

double tesseract::LSTMTrainer::ComputeErrorRates ( const NetworkIO &  deltas,
double  char_error,
double  word_error 
)
protected

Definition at line 1129 of file lstmtrainer.cpp.

1130  {
1132  // Delta error is the fraction of timesteps with >0.5 error in the top choice
1133  // score. If zero, then the top choice characters are guaranteed correct,
1134  // even when there is residue in the RMS error.
1135  double delta_error = ComputeWinnerError(deltas);
1136  UpdateErrorBuffer(delta_error, ET_DELTA);
1137  UpdateErrorBuffer(word_error, ET_WORD_RECERR);
1138  UpdateErrorBuffer(char_error, ET_CHAR_ERROR);
1139  // Skip ratio measures the difference between sample_iteration_ and
1140  // training_iteration_, which reflects the number of unusable samples,
1141  // usually due to unencodable truth text, or the text not fitting in the
1142  // space for the output.
1143  double skip_count = sample_iteration_ - prev_sample_iteration_;
1144  UpdateErrorBuffer(skip_count, ET_SKIP_RATIO);
1145  return delta_error;
1146 }
double ComputeWinnerError(const NetworkIO &deltas)
double ComputeRMSError(const NetworkIO &deltas)
void UpdateErrorBuffer(double new_error, ErrorTypes type)

◆ ComputeRMSError()

double tesseract::LSTMTrainer::ComputeRMSError ( const NetworkIO &  deltas)
protected

Definition at line 1149 of file lstmtrainer.cpp.

1149  {
1150  double total_error = 0.0;
1151  int width = deltas.Width();
1152  int num_classes = deltas.NumFeatures();
1153  for (int t = 0; t < width; ++t) {
1154  const float* class_errs = deltas.f(t);
1155  for (int c = 0; c < num_classes; ++c) {
1156  double error = class_errs[c];
1157  total_error += error * error;
1158  }
1159  }
1160  return sqrt(total_error / (width * num_classes));
1161 }

◆ ComputeTextTargets()

bool tesseract::LSTMTrainer::ComputeTextTargets ( const NetworkIO &  outputs,
const GenericVector< int > &  truth_labels,
NetworkIO *  targets 
)
protected

Definition at line 1098 of file lstmtrainer.cpp.

1100  {
1101  if (truth_labels.size() > targets->Width()) {
1102  tprintf("Error: transcription %s too long to fit into target of width %d\n",
1103  DecodeLabels(truth_labels).string(), targets->Width());
1104  return false;
1105  }
1106  for (int i = 0; i < truth_labels.size() && i < targets->Width(); ++i) {
1107  targets->SetActivations(i, truth_labels[i], 1.0);
1108  }
1109  for (int i = truth_labels.size(); i < targets->Width(); ++i) {
1110  targets->SetActivations(i, null_char_, 1.0);
1111  }
1112  return true;
1113 }
#define tprintf(...)
Definition: tprintf.h:31
int size() const
Definition: genericvector.h:72
STRING DecodeLabels(const GenericVector< int > &labels)

◆ ComputeWinnerError()

double tesseract::LSTMTrainer::ComputeWinnerError ( const NetworkIO &  deltas)
protected

Definition at line 1168 of file lstmtrainer.cpp.

1168  {
1169  int num_errors = 0;
1170  int width = deltas.Width();
1171  int num_classes = deltas.NumFeatures();
1172  for (int t = 0; t < width; ++t) {
1173  const float* class_errs = deltas.f(t);
1174  for (int c = 0; c < num_classes; ++c) {
1175  float abs_delta = fabs(class_errs[c]);
1176  // TODO(rays) Filtering cases where the delta is very large to cut out
1177  // GT errors doesn't work. Find a better way or get better truth.
1178  if (0.5 <= abs_delta)
1179  ++num_errors;
1180  }
1181  }
1182  return static_cast<double>(num_errors) / width;
1183 }

◆ ComputeWordError()

double tesseract::LSTMTrainer::ComputeWordError ( STRING *  truth_str,
STRING *  ocr_str 
)
protected

Definition at line 1214 of file lstmtrainer.cpp.

1214  {
1215  typedef std::unordered_map<std::string, int, std::hash<std::string> > StrMap;
1216  GenericVector<STRING> truth_words, ocr_words;
1217  truth_str->split(' ', &truth_words);
1218  if (truth_words.empty()) return 0.0;
1219  ocr_str->split(' ', &ocr_words);
1220  StrMap word_counts;
1221  for (int i = 0; i < truth_words.size(); ++i) {
1222  std::string truth_word(truth_words[i].string());
1223  StrMap::iterator it = word_counts.find(truth_word);
1224  if (it == word_counts.end())
1225  word_counts.insert(std::make_pair(truth_word, 1));
1226  else
1227  ++it->second;
1228  }
1229  for (int i = 0; i < ocr_words.size(); ++i) {
1230  std::string ocr_word(ocr_words[i].string());
1231  StrMap::iterator it = word_counts.find(ocr_word);
1232  if (it == word_counts.end())
1233  word_counts.insert(std::make_pair(ocr_word, -1));
1234  else
1235  --it->second;
1236  }
1237  int word_recall_errs = 0;
1238  for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1239  ++it) {
1240  if (it->second > 0) word_recall_errs += it->second;
1241  }
1242  return static_cast<double>(word_recall_errs) / truth_words.size();
1243 }
bool empty() const
Definition: genericvector.h:90
int size() const
Definition: genericvector.h:72
void split(const char c, GenericVector< STRING > *splited)
Definition: strngs.cpp:286

◆ ConvertToInt()

void tesseract::LSTMTrainer::ConvertToInt ( )
inline

Definition at line 257 of file lstmtrainer.h.

257  {
258  if ((training_flags_ & TF_INT_MODE) == 0) {
261  }
262  }
virtual void ConvertToInt()
Definition: network.h:177

◆ CurrentTrainingStage()

int tesseract::LSTMTrainer::CurrentTrainingStage ( ) const
inline

Definition at line 213 of file lstmtrainer.h.

213 { return training_stage_; }

◆ DebugLSTMTraining()

bool tesseract::LSTMTrainer::DebugLSTMTraining ( const NetworkIO &  inputs,
const ImageData &  trainingdata,
const NetworkIO &  fwd_outputs,
const GenericVector< int > &  truth_labels,
const NetworkIO &  outputs 
)
protected

Definition at line 1028 of file lstmtrainer.cpp.

1032  {
1033  const STRING& truth_text = DecodeLabels(truth_labels);
1034  if (truth_text.string() == NULL || truth_text.length() <= 0) {
1035  tprintf("Empty truth string at decode time!\n");
1036  return false;
1037  }
1038  if (debug_interval_ != 0) {
1039  // Get class labels, xcoords and string.
1040  GenericVector<int> labels;
1041  GenericVector<int> xcoords;
1042  LabelsFromOutputs(outputs, 0.0f, &labels, &xcoords);
1043  STRING text = DecodeLabels(labels);
1044  tprintf("Iteration %d: ALIGNED TRUTH : %s\n",
1045  training_iteration(), text.string());
1046  if (debug_interval_ > 0 && training_iteration() % debug_interval_ == 0) {
1047  tprintf("TRAINING activation path for truth string %s\n",
1048  truth_text.string());
1049  DebugActivationPath(outputs, labels, xcoords);
1050  DisplayForward(inputs, labels, xcoords, "LSTMTraining", &align_win_);
1051  if (OutputLossType() == LT_CTC) {
1052  DisplayTargets(fwd_outputs, "CTC Outputs", &ctc_win_);
1053  DisplayTargets(outputs, "CTC Targets", &target_win_);
1054  }
1055  }
1056  }
1057  return true;
1058 }
ScrollView * ctc_win_
Definition: lstmtrainer.h:394
LossType OutputLossType() const
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
ScrollView * target_win_
Definition: lstmtrainer.h:392
inT32 length() const
Definition: strngs.cpp:193
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
Definition: strngs.h:45
void LabelsFromOutputs(const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
STRING DecodeLabels(const GenericVector< int > &labels)
ScrollView * align_win_
Definition: lstmtrainer.h:390

◆ DebugNetwork()

void tesseract::LSTMTrainer::DebugNetwork ( )

Definition at line 298 of file lstmtrainer.cpp.

298  {
300 }
virtual void DebugWeights()
Definition: network.h:204

◆ DeSerialize()

bool tesseract::LSTMTrainer::DeSerialize ( TFile *  fp)
virtual

Definition at line 483 of file lstmtrainer.cpp.

483  {
484  if (!LSTMRecognizer::DeSerialize(fp)) return false;
485  if (fp->FRead(&learning_iteration_, sizeof(learning_iteration_), 1) != 1) {
486  // Special case. If we successfully decoded the recognizer, but fail here
487  // then it means we were just given a recognizer, so issue a warning and
488  // allow it.
489  tprintf("Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
492  return true;
493  }
494  if (fp->FReadEndian(&prev_sample_iteration_, sizeof(prev_sample_iteration_),
495  1) != 1)
496  return false;
497  if (fp->FReadEndian(&perfect_delay_, sizeof(perfect_delay_), 1) != 1)
498  return false;
499  if (fp->FReadEndian(&last_perfect_training_iteration_,
500  sizeof(last_perfect_training_iteration_), 1) != 1)
501  return false;
502  for (int i = 0; i < ET_COUNT; ++i) {
503  if (!error_buffers_[i].DeSerialize(fp)) return false;
504  }
505  if (fp->FRead(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
506  if (fp->FReadEndian(&training_stage_, sizeof(training_stage_), 1) != 1)
507  return false;
508  uinT8 amount;
509  if (fp->FRead(&amount, sizeof(amount), 1) != 1) return false;
510  if (amount == LIGHT) return true; // Don't read the rest.
511  if (fp->FReadEndian(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
512  return false;
513  if (fp->FReadEndian(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
514  return false;
515  if (fp->FReadEndian(&best_iteration_, sizeof(best_iteration_), 1) != 1)
516  return false;
517  if (fp->FReadEndian(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
518  return false;
519  if (fp->FReadEndian(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
520  return false;
521  if (fp->FReadEndian(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
522  return false;
523  if (fp->FReadEndian(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
524  return false;
525  if (!best_model_data_.DeSerialize(fp)) return false;
526  if (!worst_model_data_.DeSerialize(fp)) return false;
527  if (amount != NO_BEST_TRAINER && !best_trainer_.DeSerialize(fp)) return false;
528  GenericVector<char> sub_data;
529  if (!sub_data.DeSerialize(fp)) return false;
530  delete sub_trainer_;
531  if (sub_data.empty()) {
532  sub_trainer_ = NULL;
533  } else {
534  sub_trainer_ = new LSTMTrainer();
535  if (!ReadTrainingDump(sub_data, sub_trainer_)) return false;
536  }
537  if (!best_error_history_.DeSerialize(fp)) return false;
538  if (!best_error_iterations_.DeSerialize(fp)) return false;
539  if (fp->FReadEndian(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
540  return false;
541  return true;
542 }
bool DeSerialize(bool swap, FILE *fp)
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:451
#define tprintf(...)
Definition: tprintf.h:31
bool empty() const
Definition: genericvector.h:90
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
virtual bool DeSerialize(TFile *fp)
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:426
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:439
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:452
uint8_t uinT8
Definition: host.h:35
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112

◆ DisplayTargets()

void tesseract::LSTMTrainer::DisplayTargets ( const NetworkIO &  targets,
const char *  window_name,
ScrollView **  window 
)
protected

Definition at line 1061 of file lstmtrainer.cpp.

1062  {
1063 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics.
1064  int width = targets.Width();
1065  int num_features = targets.NumFeatures();
1066  Network::ClearWindow(true, window_name, width * kTargetXScale, kTargetYScale,
1067  window);
1068  for (int c = 0; c < num_features; ++c) {
1069  int color = c % (ScrollView::GREEN_YELLOW - 1) + 2;
1070  (*window)->Pen(static_cast<ScrollView::Color>(color));
1071  int start_t = -1;
1072  for (int t = 0; t < width; ++t) {
1073  double target = targets.f(t)[c];
1074  target *= kTargetYScale;
1075  if (target >= 1) {
1076  if (start_t < 0) {
1077  (*window)->SetCursor(t - 1, 0);
1078  start_t = t;
1079  }
1080  (*window)->DrawTo(t, target);
1081  } else if (start_t >= 0) {
1082  (*window)->DrawTo(t, 0);
1083  (*window)->DrawTo(start_t - 1, 0);
1084  start_t = -1;
1085  }
1086  }
1087  if (start_t >= 0) {
1088  (*window)->DrawTo(width, 0);
1089  (*window)->DrawTo(start_t - 1, 0);
1090  }
1091  }
1092  (*window)->Update();
1093 #endif // GRAPHICS_DISABLED
1094 }
const int kTargetYScale
Definition: lstmtrainer.cpp:71
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
const int kTargetXScale
Definition: lstmtrainer.cpp:70

◆ DumpFilename()

STRING tesseract::LSTMTrainer::DumpFilename ( ) const

Definition at line 951 of file lstmtrainer.cpp.

951  {
954  filename.add_str_int("_", best_iteration_);
955  filename += ".lstm";
956  return filename;
957 }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
const char * string() const
Definition: strngs.cpp:198
Definition: strngs.h:45
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const char * filename
Definition: ioapi.h:38

◆ EmptyConstructor()

void tesseract::LSTMTrainer::EmptyConstructor ( )
protected

Definition at line 967 of file lstmtrainer.cpp.

967  {
968  align_win_ = NULL;
969  target_win_ = NULL;
970  ctc_win_ = NULL;
971  recon_win_ = NULL;
974  training_stage_ = 0;
976  InitIterations();
977 }
ScrollView * ctc_win_
Definition: lstmtrainer.h:394
SerializeAmount serialize_amount_
Definition: lstmtrainer.h:408
ScrollView * target_win_
Definition: lstmtrainer.h:392
ScrollView * recon_win_
Definition: lstmtrainer.h:396
ScrollView * align_win_
Definition: lstmtrainer.h:390

◆ EncodeString() [1/2]

bool tesseract::LSTMTrainer::EncodeString ( const STRING str,
GenericVector< int > *  labels 
) const
inline

Definition at line 247 of file lstmtrainer.h.

247  {
248  return EncodeString(str, GetUnicharset(), IsRecoding() ? &recoder_ : NULL,
249  SimpleTextOutput(), null_char_, labels);
250  }
const UNICHARSET & GetUnicharset() const
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:247

◆ EncodeString() [2/2]

bool tesseract::LSTMTrainer::EncodeString ( const STRING str,
const UNICHARSET unicharset,
const UnicharCompress recoder,
bool  simple_text,
int  null_char,
GenericVector< int > *  labels 
)
static

Definition at line 748 of file lstmtrainer.cpp.

750  {
751  if (str.string() == NULL || str.length() <= 0) {
752  tprintf("Empty truth string!\n");
753  return false;
754  }
755  int err_index;
756  GenericVector<int> internal_labels;
757  labels->truncate(0);
758  if (!simple_text) labels->push_back(null_char);
759  if (unicharset.encode_string(str.string(), true, &internal_labels, NULL,
760  &err_index)) {
761  bool success = true;
762  for (int i = 0; i < internal_labels.size(); ++i) {
763  if (recoder != NULL) {
764  // Re-encode labels via recoder.
765  RecodedCharID code;
766  int len = recoder->EncodeUnichar(internal_labels[i], &code);
767  if (len > 0) {
768  for (int j = 0; j < len; ++j) {
769  labels->push_back(code(j));
770  if (!simple_text) labels->push_back(null_char);
771  }
772  } else {
773  success = false;
774  err_index = 0;
775  break;
776  }
777  } else {
778  labels->push_back(internal_labels[i]);
779  if (!simple_text) labels->push_back(null_char);
780  }
781  }
782  if (success) return true;
783  }
784  tprintf("Encoding of string failed! Failure bytes:");
785  while (err_index < str.length()) {
786  tprintf(" %x", str[err_index++]);
787  }
788  tprintf("\n");
789  return false;
790 }
int push_back(T object)
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
void truncate(int size)
inT32 length() const
Definition: strngs.cpp:193
int size() const
Definition: genericvector.h:72
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
Definition: unicharset.cpp:234

◆ error_rates()

const double* tesseract::LSTMTrainer::error_rates ( ) const
inline

Definition at line 140 of file lstmtrainer.h.

140  {
141  return error_rates_;
142  }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475

◆ FillErrorBuffer()

void tesseract::LSTMTrainer::FillErrorBuffer ( double  new_error,
ErrorTypes  type 
)

Definition at line 960 of file lstmtrainer.cpp.

960  {
961  for (int i = 0; i < kRollingBufferSize_; ++i)
962  error_buffers_[type][i] = new_error;
963  error_rates_[type] = 100.0 * new_error;
964 }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472

◆ GridSearchDictParams()

Trainability tesseract::LSTMTrainer::GridSearchDictParams ( const ImageData trainingdata,
int  iteration,
double  min_dict_ratio,
double  dict_ratio_step,
double  max_dict_ratio,
double  min_cert_offset,
double  cert_offset_step,
double  max_cert_offset,
STRING results 
)

Definition at line 248 of file lstmtrainer.cpp.

251  {
252  sample_iteration_ = iteration;
253  NetworkIO fwd_outputs, targets;
254  Trainability result =
255  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
256  if (result == UNENCODABLE || result == HI_PRECISION_ERR || dict_ == NULL)
257  return result;
258 
259  // Encode/decode the truth to get the normalization.
260  GenericVector<int> truth_labels, ocr_labels, xcoords;
261  ASSERT_HOST(EncodeString(trainingdata->transcription(), &truth_labels));
262  // NO-dict error.
263  RecodeBeamSearch base_search(recoder_, null_char_, SimpleTextOutput(), NULL);
264  base_search.Decode(fwd_outputs, 1.0, 0.0, RecodeBeamSearch::kMinCertainty,
265  NULL);
266  base_search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
267  STRING truth_text = DecodeLabels(truth_labels);
268  STRING ocr_text = DecodeLabels(ocr_labels);
269  double baseline_error = ComputeWordError(&truth_text, &ocr_text);
270  results->add_str_double("0,0=", baseline_error);
271 
272  RecodeBeamSearch search(recoder_, null_char_, SimpleTextOutput(), dict_);
273  for (double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
274  for (double c = min_cert_offset; c < max_cert_offset;
275  c += cert_offset_step) {
276  search.Decode(fwd_outputs, r, c, RecodeBeamSearch::kMinCertainty, NULL);
277  search.ExtractBestPathAsLabels(&ocr_labels, &xcoords);
278  truth_text = DecodeLabels(truth_labels);
279  ocr_text = DecodeLabels(ocr_labels);
280  // This is destructive on both strings.
281  double word_error = ComputeWordError(&truth_text, &ocr_text);
282  if ((r == min_dict_ratio && c == min_cert_offset) ||
283  !std::isfinite(word_error)) {
284  STRING t = DecodeLabels(truth_labels);
285  STRING o = DecodeLabels(ocr_labels);
286  tprintf("r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
287  t.string(), o.string(), word_error, truth_labels[0]);
288  }
289  results->add_str_double(" ", r);
290  results->add_str_double(",", c);
291  results->add_str_double("=", word_error);
292  }
293  }
294  return result;
295 }
static const float kMinCertainty
Definition: recodebeam.h:213
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
#define ASSERT_HOST(x)
Definition: errcode.h:84
LIST search(LIST list, void *key, int_compare is_equal)
Definition: oldlist.cpp:406
Definition: strngs.h:45
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
STRING DecodeLabels(const GenericVector< int > &labels)
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:247

◆ improvement_steps()

int tesseract::LSTMTrainer::improvement_steps ( ) const
inline

Definition at line 150 of file lstmtrainer.h.

150 { return improvement_steps_; }

◆ InitCharSet() [1/2]

void tesseract::LSTMTrainer::InitCharSet ( const UNICHARSET unicharset,
const STRING script_dir,
int  train_flags 
)

Definition at line 138 of file lstmtrainer.cpp.

139  {
141  training_flags_ = train_flags;
142  ccutil_.unicharset.CopyFrom(unicharset);
144  : GetUnicharset().size();
145  SetUnicharsetProperties(script_dir);
146 }
const UNICHARSET & GetUnicharset() const
bool has_special_codes() const
Definition: unicharset.h:682
UNICHARSET unicharset
Definition: ccutil.h:68
void SetUnicharsetProperties(const STRING &script_dir)
int size() const
Definition: unicharset.h:299
void CopyFrom(const UNICHARSET &src)
Definition: unicharset.cpp:423

◆ InitCharSet() [2/2]

void tesseract::LSTMTrainer::InitCharSet ( const UNICHARSET unicharset,
const UnicharCompress recoder 
)

Definition at line 152 of file lstmtrainer.cpp.

153  {
155  int flags = TF_COMPRESS_UNICHARSET;
156  training_flags_ = static_cast<TrainingFlags>(flags);
157  ccutil_.unicharset.CopyFrom(unicharset);
158  recoder_ = recoder;
160  : GetUnicharset().size();
161  RecodedCharID code;
163  null_char_ = code(0);
164  // Space should encode as itself.
166  ASSERT_HOST(code(0) == UNICHAR_SPACE);
167 }
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
const UNICHARSET & GetUnicharset() const
#define ASSERT_HOST(x)
Definition: errcode.h:84
bool has_special_codes() const
Definition: unicharset.h:682
UNICHARSET unicharset
Definition: ccutil.h:68
int size() const
Definition: unicharset.h:299
void CopyFrom(const UNICHARSET &src)
Definition: unicharset.cpp:423

◆ InitIterations()

void tesseract::LSTMTrainer::InitIterations ( )

Definition at line 223 of file lstmtrainer.cpp.

223  {
224  sample_iteration_ = 0;
228  best_error_rate_ = 100.0;
229  best_iteration_ = 0;
230  worst_error_rate_ = 0.0;
231  worst_iteration_ = 0;
234  perfect_delay_ = 0;
236  for (int i = 0; i < ET_COUNT; ++i) {
237  best_error_rates_[i] = 100.0;
238  worst_error_rates_[i] = 0.0;
240  error_rates_[i] = 100.0;
241  }
243 }
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
void init_to_size(int size, T t)
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:446
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:426
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60

◆ InitNetwork()

bool tesseract::LSTMTrainer::InitNetwork ( const STRING network_spec,
int  append_index,
int  net_flags,
float  weight_range,
float  learning_rate,
float  momentum 
)

Definition at line 175 of file lstmtrainer.cpp.

177  {
178  // Call after InitCharSet.
180  weight_range_ = weight_range;
182  momentum_ = momentum;
183  int num_outputs = null_char_ == GetUnicharset().size()
184  ? null_char_ + 1
185  : GetUnicharset().size();
186  if (IsRecoding()) num_outputs = recoder_.code_range();
187  if (!NetworkBuilder::InitNetwork(num_outputs, network_spec, append_index,
188  net_flags, weight_range, &randomizer_,
189  &network_)) {
190  return false;
191  }
192  network_str_ += network_spec;
193  tprintf("Built network:%s from request %s\n",
194  network_->spec().string(), network_spec.string());
195  tprintf("Training parameters:\n Debug interval = %d,"
196  " weights = %g, learning rate = %g, momentum=%g\n",
198  return true;
199 }
voidpf void uLong size
Definition: ioapi.h:39
const UNICHARSET & GetUnicharset() const
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
#define ASSERT_HOST(x)
Definition: errcode.h:84
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
double learning_rate() const
int size() const
Definition: unicharset.h:299
virtual STRING spec() const
Definition: network.h:141

◆ InitTensorFlowNetwork()

int tesseract::LSTMTrainer::InitTensorFlowNetwork ( const std::string &  tf_proto)

Definition at line 203 of file lstmtrainer.cpp.

203  {
204 #ifdef INCLUDE_TENSORFLOW
205  delete network_;
206  TFNetwork* tf_net = new TFNetwork("TensorFlow");
207  training_iteration_ = tf_net->InitFromProtoStr(tf_proto);
208  if (training_iteration_ == 0) {
209  tprintf("InitFromProtoStr failed!!\n");
210  return 0;
211  }
212  network_ = tf_net;
213  ASSERT_HOST(recoder_.code_range() == tf_net->num_classes());
214  return training_iteration_;
215 #else
216  tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
217  return 0;
218 #endif
219 }
#define tprintf(...)
Definition: tprintf.h:31
#define ASSERT_HOST(x)
Definition: errcode.h:84

◆ LastSingleError()

double tesseract::LSTMTrainer::LastSingleError ( ErrorTypes  type) const
inline

Definition at line 160 of file lstmtrainer.h.

160  {
161  return error_buffers_[type]
164  }
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472

◆ learning_iteration()

int tesseract::LSTMTrainer::learning_iteration ( ) const
inline

Definition at line 149 of file lstmtrainer.h.

149 { return learning_iteration_; }

◆ LoadAllTrainingData()

bool tesseract::LSTMTrainer::LoadAllTrainingData ( const GenericVector< STRING > &  filenames)

Definition at line 305 of file lstmtrainer.cpp.

305  {
308 }
DocumentCache training_data_
Definition: lstmtrainer.h:406
CachingStrategy CacheStrategy() const
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
Definition: imagedata.cpp:572

◆ LogIterations()

void tesseract::LSTMTrainer::LogIterations ( const char *  intro_str,
STRING log_msg 
) const

Definition at line 414 of file lstmtrainer.cpp.

414  {
415  *log_msg += intro_str;
416  log_msg->add_str_int(" iteration ", learning_iteration());
417  log_msg->add_str_int("/", training_iteration());
418  log_msg->add_str_int("/", sample_iteration());
419 }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
int learning_iteration() const
Definition: lstmtrainer.h:149

◆ MaintainCheckpoints()

bool tesseract::LSTMTrainer::MaintainCheckpoints ( TestCallback  tester,
STRING log_msg 
)

Definition at line 314 of file lstmtrainer.cpp.

314  {
315  PrepareLogMsg(log_msg);
316  double error_rate = CharError();
317  int iteration = learning_iteration();
318  if (iteration >= stall_iteration_ &&
319  error_rate > best_error_rate_ * (1.0 + kSubTrainerMarginFraction) &&
321  // It hasn't got any better in a long while, and is a margin worse than the
322  // best, so go back to the best model and try a different learning rate.
323  StartSubtrainer(log_msg);
324  }
325  SubTrainerResult sub_trainer_result = STR_NONE;
326  if (sub_trainer_ != NULL) {
327  sub_trainer_result = UpdateSubtrainer(log_msg);
328  if (sub_trainer_result == STR_REPLACED) {
329  // Reset the inputs, as we have overwritten *this.
330  error_rate = CharError();
331  iteration = learning_iteration();
332  PrepareLogMsg(log_msg);
333  }
334  }
335  bool result = true; // Something interesting happened.
336  GenericVector<char> rec_model_data;
337  if (error_rate < best_error_rate_) {
338  SaveRecognitionDump(&rec_model_data);
339  log_msg->add_str_double(" New best char error = ", error_rate);
340  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
341  // If sub_trainer_ is not NULL, either *this beat it to a new best, or it
342  // just overwrote *this. In either case, we have finished with it.
343  delete sub_trainer_;
344  sub_trainer_ = NULL;
347  log_msg->add_str_int(" Transitioned to stage ", CurrentTrainingStage());
348  }
351  STRING best_model_name = DumpFilename();
352  if (!(*file_writer_)(best_trainer_, best_model_name)) {
353  *log_msg += " failed to write best model:";
354  } else {
355  *log_msg += " wrote best model:";
357  }
358  *log_msg += best_model_name;
359  }
360  } else if (error_rate > worst_error_rate_) {
361  SaveRecognitionDump(&rec_model_data);
362  log_msg->add_str_double(" New worst char error = ", error_rate);
363  *log_msg += UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
366  // Error rate has ballooned. Go back to the best model.
367  *log_msg += "\nDivergence! ";
368  // Copy best_trainer_ before reading it, as it will get overwritten.
369  GenericVector<char> revert_data(best_trainer_);
370  if (checkpoint_reader_->Run(revert_data, this)) {
371  LogIterations("Reverted to", log_msg);
372  ReduceLearningRates(this, log_msg);
373  } else {
374  LogIterations("Failed to Revert at", log_msg);
375  }
376  // If it fails again, we will wait twice as long before reverting again.
377  stall_iteration_ = iteration + 2 * (iteration - learning_iteration());
378  // Re-save the best trainer with the new learning rates and stall
379  // iteration.
381  }
382  } else {
383  // Something interesting happened only if the sub_trainer_ was trained.
384  result = sub_trainer_result != STR_NONE;
385  }
386  if (checkpoint_writer_ != NULL && file_writer_ != NULL &&
387  checkpoint_name_.length() > 0) {
388  // Write a current checkpoint.
389  GenericVector<char> checkpoint;
390  if (!checkpoint_writer_->Run(FULL, this, &checkpoint) ||
391  !(*file_writer_)(checkpoint, checkpoint_name_)) {
392  *log_msg += " failed to write checkpoint.";
393  } else {
394  *log_msg += " wrote checkpoint.";
395  }
396  }
397  *log_msg += "\n";
398  return result;
399 }
void PrepareLogMsg(STRING *log_msg) const
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
int CurrentTrainingStage() const
Definition: lstmtrainer.h:213
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
const int kMinStallIterations
Definition: lstmtrainer.cpp:47
const double kMinDivergenceRate
Definition: lstmtrainer.cpp:45
virtual R Run(A1, A2)=0
bool empty() const
Definition: genericvector.h:90
void LogIterations(const char *intro_str, STRING *log_msg) const
inT32 length() const
Definition: strngs.cpp:193
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
const double kStageTransitionThreshold
Definition: lstmtrainer.cpp:62
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
float error_rate_of_last_saved_best_
Definition: lstmtrainer.h:446
void StartSubtrainer(STRING *log_msg)
Definition: strngs.h:45
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418
int learning_iteration() const
Definition: lstmtrainer.h:149
STRING DumpFilename() const
virtual R Run(A1, A2, A3)=0
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
const int kMinStartedErrorRate
Definition: lstmtrainer.cpp:60
double CharError() const
Definition: lstmtrainer.h:139
bool TransitionTrainingStage(float error_threshold)
const double kBestCheckpointFraction
Definition: lstmtrainer.cpp:68
void SaveRecognitionDump(GenericVector< char > *data) const

◆ MaintainCheckpointsSpecific()

bool tesseract::LSTMTrainer::MaintainCheckpointsSpecific ( int  iteration,
const GenericVector< char > *  train_model,
const GenericVector< char > *  rec_model,
TestCallback  tester,
STRING log_msg 
)

◆ mutable_training_data()

DocumentCache* tesseract::LSTMTrainer::mutable_training_data ( )
inline

Definition at line 168 of file lstmtrainer.h.

168 { return &training_data_; }
DocumentCache training_data_
Definition: lstmtrainer.h:406

◆ NewSingleError()

double tesseract::LSTMTrainer::NewSingleError ( ErrorTypes  type) const
inline

Definition at line 154 of file lstmtrainer.h.

154  {
156  }
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472

◆ PrepareForBackward()

Trainability tesseract::LSTMTrainer::PrepareForBackward ( const ImageData trainingdata,
NetworkIO fwd_outputs,
NetworkIO targets 
)

Definition at line 827 of file lstmtrainer.cpp.

829  {
830  if (trainingdata == NULL) {
831  tprintf("Null trainingdata.\n");
832  return UNENCODABLE;
833  }
834  // Ensure repeatability of random elements even across checkpoints.
835  bool debug = debug_interval_ > 0 &&
837  GenericVector<int> truth_labels;
838  if (!EncodeString(trainingdata->transcription(), &truth_labels)) {
839  tprintf("Can't encode transcription: %s\n",
840  trainingdata->transcription().string());
841  return UNENCODABLE;
842  }
843  int w = 0;
844  while (w < truth_labels.size() &&
845  (truth_labels[w] == UNICHAR_SPACE || truth_labels[w] == null_char_))
846  ++w;
847  if (w == truth_labels.size()) {
848  tprintf("Blank transcription: %s\n",
849  trainingdata->transcription().string());
850  return UNENCODABLE;
851  }
852  float image_scale;
853  NetworkIO inputs;
854  bool invert = trainingdata->boxes().empty();
855  if (!RecognizeLine(*trainingdata, invert, debug, invert, 0.0f, &image_scale,
856  &inputs, fwd_outputs)) {
857  tprintf("Image not trainable\n");
858  return UNENCODABLE;
859  }
860  targets->Resize(*fwd_outputs, network_->NumOutputs());
861  LossType loss_type = OutputLossType();
862  if (loss_type == LT_SOFTMAX) {
863  if (!ComputeTextTargets(*fwd_outputs, truth_labels, targets)) {
864  tprintf("Compute simple targets failed!\n");
865  return UNENCODABLE;
866  }
867  } else if (loss_type == LT_CTC) {
868  if (!ComputeCTCTargets(truth_labels, fwd_outputs, targets)) {
869  tprintf("Compute CTC targets failed!\n");
870  return UNENCODABLE;
871  }
872  } else {
873  tprintf("Logistic outputs not implemented yet!\n");
874  return UNENCODABLE;
875  }
876  GenericVector<int> ocr_labels;
877  GenericVector<int> xcoords;
878  LabelsFromOutputs(*fwd_outputs, 0.0f, &ocr_labels, &xcoords);
879  // CTC does not produce correct target labels to begin with.
880  if (loss_type != LT_CTC) {
881  LabelsFromOutputs(*targets, 0.0f, &truth_labels, &xcoords);
882  }
883  if (!DebugLSTMTraining(inputs, *trainingdata, *fwd_outputs, truth_labels,
884  *targets)) {
885  tprintf("Input width was %d\n", inputs.Width());
886  return UNENCODABLE;
887  }
888  STRING ocr_text = DecodeLabels(ocr_labels);
889  STRING truth_text = DecodeLabels(truth_labels);
890  targets->SubtractAllFromFloat(*fwd_outputs);
891  if (debug_interval_ != 0) {
892  tprintf("Iteration %d: BEST OCR TEXT : %s\n", training_iteration(),
893  ocr_text.string());
894  }
895  double char_error = ComputeCharError(truth_labels, ocr_labels);
896  double word_error = ComputeWordError(&truth_text, &ocr_text);
897  double delta_error = ComputeErrorRates(*targets, char_error, word_error);
898  if (debug_interval_ != 0) {
899  tprintf("File %s page %d %s:\n", trainingdata->imagefilename().string(),
900  trainingdata->page_number(), delta_error == 0.0 ? "(Perfect)" : "");
901  }
902  if (delta_error == 0.0) return PERFECT;
903  if (targets->AnySuspiciousTruth(kHighConfidence)) return HI_PRECISION_ERR;
904  return TRAINABLE;
905 }
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET *target_unicharset, const TBOX &line_box, float score_ratio, bool one_word, PointerVector< WERD_RES > *words)
LossType OutputLossType() const
const double kHighConfidence
Definition: lstmtrainer.cpp:64
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
int size() const
Definition: genericvector.h:72
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
Definition: strngs.h:45
void LabelsFromOutputs(const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
STRING DecodeLabels(const GenericVector< int > &labels)
int NumOutputs() const
Definition: network.h:123
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
Definition: lstmtrainer.h:247

◆ PrepareLogMsg()

void tesseract::LSTMTrainer::PrepareLogMsg ( STRING log_msg) const

Definition at line 402 of file lstmtrainer.cpp.

402  {
403  LogIterations("At", log_msg);
404  log_msg->add_str_double(", Mean rms=", error_rates_[ET_RMS]);
405  log_msg->add_str_double("%, delta=", error_rates_[ET_DELTA]);
406  log_msg->add_str_double("%, char train=", error_rates_[ET_CHAR_ERROR]);
407  log_msg->add_str_double("%, word train=", error_rates_[ET_WORD_RECERR]);
408  log_msg->add_str_double("%, skip ratio=", error_rates_[ET_SKIP_RATIO]);
409  *log_msg += "%, ";
410 }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
void LogIterations(const char *intro_str, STRING *log_msg) const
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391

◆ ReadRecognitionDump()

LSTMRecognizer * tesseract::LSTMTrainer::ReadRecognitionDump ( const GenericVector< char > &  data)
static

Definition at line 940 of file lstmtrainer.cpp.

941  {
942  TFile fp;
943  fp.Open(&data[0], data.size());
944  LSTMRecognizer* recognizer = new LSTMRecognizer;
945  ASSERT_HOST(recognizer->DeSerialize(&fp));
946  return recognizer;
947 }
int size() const
Definition: genericvector.h:72
#define ASSERT_HOST(x)
Definition: errcode.h:84

◆ ReadSizedTrainingDump()

bool tesseract::LSTMTrainer::ReadSizedTrainingDump ( const char *  data,
int  size 
)

Definition at line 924 of file lstmtrainer.cpp.

924  {
925  TFile fp;
926  fp.Open(data, size);
927  return DeSerialize(&fp);
928 }
voidpf void uLong size
Definition: ioapi.h:39
virtual bool DeSerialize(TFile *fp)

◆ ReadTrainingDump()

bool tesseract::LSTMTrainer::ReadTrainingDump ( const GenericVector< char > &  data,
LSTMTrainer trainer 
)

Definition at line 919 of file lstmtrainer.cpp.

920  {
921  return trainer->ReadSizedTrainingDump(&data[0], data.size());
922 }
int size() const
Definition: genericvector.h:72

◆ ReduceLayerLearningRates()

int tesseract::LSTMTrainer::ReduceLayerLearningRates ( double  factor,
int  num_samples,
LSTMTrainer samples_trainer 
)

Definition at line 639 of file lstmtrainer.cpp.

640  {
641  enum WhichWay {
642  LR_DOWN, // Learning rate will go down by factor.
643  LR_SAME, // Learning rate will stay the same.
644  LR_COUNT // Size of arrays.
645  };
646  // Epsilon is so small that it may as well be zero, but still positive.
647  const double kEpsilon = 1.0e-30;
649  int num_layers = layers.size();
650  GenericVector<int> num_weights;
651  num_weights.init_to_size(num_layers, 0);
652  GenericVector<double> bad_sums[LR_COUNT];
653  GenericVector<double> ok_sums[LR_COUNT];
654  for (int i = 0; i < LR_COUNT; ++i) {
655  bad_sums[i].init_to_size(num_layers, 0.0);
656  ok_sums[i].init_to_size(num_layers, 0.0);
657  }
658  double momentum_factor = 1.0 / (1.0 - momentum_);
659  GenericVector<char> orig_trainer;
660  SaveTrainingDump(LIGHT, this, &orig_trainer);
661  for (int i = 0; i < num_layers; ++i) {
662  Network* layer = GetLayer(layers[i]);
663  num_weights[i] = layer->IsTraining() ? layer->num_weights() : 0;
664  }
665  int iteration = sample_iteration();
666  for (int s = 0; s < num_samples; ++s) {
667  // Which way will we modify the learning rate?
668  for (int ww = 0; ww < LR_COUNT; ++ww) {
669  // Transfer momentum to learning rate and adjust by the ww factor.
670  float ww_factor = momentum_factor;
671  if (ww == LR_DOWN) ww_factor *= factor;
672  // Make a copy of *this, so we can mess about without damaging anything.
673  LSTMTrainer copy_trainer;
674  copy_trainer.ReadTrainingDump(orig_trainer, &copy_trainer);
675  // Clear the updates, doing nothing else.
676  copy_trainer.network_->Update(0.0, 0.0, 0);
677  // Adjust the learning rate in each layer.
678  for (int i = 0; i < num_layers; ++i) {
679  if (num_weights[i] == 0) continue;
680  copy_trainer.ScaleLayerLearningRate(layers[i], ww_factor);
681  }
682  copy_trainer.SetIteration(iteration);
683  // Train on the sample, but keep the update in updates_ instead of
684  // applying to the weights.
685  const ImageData* trainingdata =
686  copy_trainer.TrainOnLine(samples_trainer, true);
687  if (trainingdata == NULL) continue;
688  // We'll now use this trainer again for each layer.
689  GenericVector<char> updated_trainer;
690  SaveTrainingDump(LIGHT, &copy_trainer, &updated_trainer);
691  for (int i = 0; i < num_layers; ++i) {
692  if (num_weights[i] == 0) continue;
693  LSTMTrainer layer_trainer;
694  layer_trainer.ReadTrainingDump(updated_trainer, &layer_trainer);
695  Network* layer = layer_trainer.GetLayer(layers[i]);
696  // Update the weights in just the layer, and also zero the updates
697  // matrix (to epsilon).
698  layer->Update(0.0, kEpsilon, 0);
699  // Train again on the same sample, again holding back the updates.
700  layer_trainer.TrainOnLine(trainingdata, true);
701  // Count the sign changes in the updates in layer vs in copy_trainer.
702  float before_bad = bad_sums[ww][i];
703  float before_ok = ok_sums[ww][i];
704  layer->CountAlternators(*copy_trainer.GetLayer(layers[i]),
705  &ok_sums[ww][i], &bad_sums[ww][i]);
706  float bad_frac =
707  bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
708  if (bad_frac > 0.0f)
709  bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
710  }
711  }
712  ++iteration;
713  }
714  int num_lowered = 0;
715  for (int i = 0; i < num_layers; ++i) {
716  if (num_weights[i] == 0) continue;
717  Network* layer = GetLayer(layers[i]);
718  float lr = GetLayerLearningRate(layers[i]);
719  double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
720  double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
721  double frac_down = bad_sums[LR_DOWN][i] / total_down;
722  double frac_same = bad_sums[LR_SAME][i] / total_same;
723  tprintf("Layer %d=%s: lr %g->%g%%, lr %g->%g%%", i, layer->name().string(),
724  lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
725  if (frac_down < frac_same * kImprovementFraction) {
726  tprintf(" REDUCED\n");
727  ScaleLayerLearningRate(layers[i], factor);
728  ++num_lowered;
729  } else {
730  tprintf(" SAME\n");
731  }
732  }
733  if (num_lowered == 0) {
734  // Just lower everything to make sure.
735  for (int i = 0; i < num_layers; ++i) {
736  if (num_weights[i] > 0) {
737  ScaleLayerLearningRate(layers[i], factor);
738  ++num_lowered;
739  }
740  }
741  }
742  return num_lowered;
743 }
float GetLayerLearningRate(const STRING &id) const
void init_to_size(int size, T t)
#define tprintf(...)
Definition: tprintf.h:31
GenericVector< STRING > EnumerateLayers() const
const double kImprovementFraction
Definition: lstmtrainer.cpp:66
int size() const
Definition: genericvector.h:72
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
Network * GetLayer(const STRING &id) const
void ScaleLayerLearningRate(const STRING &id, double factor)

◆ ReduceLearningRates()

void tesseract::LSTMTrainer::ReduceLearningRates ( LSTMTrainer samples_trainer,
STRING log_msg 
)

Definition at line 620 of file lstmtrainer.cpp.

621  {
623  int num_reduced = ReduceLayerLearningRates(
624  kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
625  log_msg->add_str_int("\nReduced learning rate on layers: ", num_reduced);
626  } else {
628  log_msg->add_str_double("\nReduced learning rate to :", learning_rate_);
629  }
630  *log_msg += "\n";
631 }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
const int kNumAdjustmentIterations
Definition: lstmtrainer.cpp:54
void ScaleLearningRate(double factor)
const double kLearningRateDecay
Definition: lstmtrainer.cpp:52
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)

◆ RollErrorBuffers()

void tesseract::LSTMTrainer::RollErrorBuffers ( )
protected

Definition at line 1260 of file lstmtrainer.cpp.

1260  {
1262  if (NewSingleError(ET_DELTA) > 0.0)
1264  else
1267  if (debug_interval_ != 0) {
1268  tprintf("Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1272  }
1273 }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
#define tprintf(...)
Definition: tprintf.h:31
double NewSingleError(ErrorTypes type) const
Definition: lstmtrainer.h:154

◆ SaveBestModel()

bool tesseract::LSTMTrainer::SaveBestModel ( FileWriter  writer) const

◆ SaveRecognitionDump()

void tesseract::LSTMTrainer::SaveRecognitionDump ( GenericVector< char > *  data) const

Definition at line 931 of file lstmtrainer.cpp.

931  {
932  TFile fp;
933  fp.OpenWrite(data);
937 }
#define ASSERT_HOST(x)
Definition: errcode.h:84
bool Serialize(TFile *fp) const
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112

◆ SaveTrainingDump()

bool tesseract::LSTMTrainer::SaveTrainingDump ( SerializeAmount  serialize_amount,
const LSTMTrainer trainer,
GenericVector< char > *  data 
) const

Definition at line 909 of file lstmtrainer.cpp.

911  {
912  TFile fp;
913  fp.OpenWrite(data);
914  trainer->serialize_amount_ = serialize_amount;
915  return trainer->Serialize(&fp);
916 }

◆ Serialize()

bool tesseract::LSTMTrainer::Serialize ( TFile fp) const
virtual

Definition at line 433 of file lstmtrainer.cpp.

433  {
434  if (!LSTMRecognizer::Serialize(fp)) return false;
435  if (fp->FWrite(&learning_iteration_, sizeof(learning_iteration_), 1) != 1)
436  return false;
437  if (fp->FWrite(&prev_sample_iteration_, sizeof(prev_sample_iteration_), 1) !=
438  1)
439  return false;
440  if (fp->FWrite(&perfect_delay_, sizeof(perfect_delay_), 1) != 1) return false;
441  if (fp->FWrite(&last_perfect_training_iteration_,
442  sizeof(last_perfect_training_iteration_), 1) != 1)
443  return false;
444  for (int i = 0; i < ET_COUNT; ++i) {
445  if (!error_buffers_[i].Serialize(fp)) return false;
446  }
447  if (fp->FWrite(&error_rates_, sizeof(error_rates_), 1) != 1) return false;
448  if (fp->FWrite(&training_stage_, sizeof(training_stage_), 1) != 1)
449  return false;
450  uinT8 amount = serialize_amount_;
451  if (fp->FWrite(&amount, sizeof(amount), 1) != 1) return false;
452  if (amount == LIGHT) return true; // We are done.
453  if (fp->FWrite(&best_error_rate_, sizeof(best_error_rate_), 1) != 1)
454  return false;
455  if (fp->FWrite(&best_error_rates_, sizeof(best_error_rates_), 1) != 1)
456  return false;
457  if (fp->FWrite(&best_iteration_, sizeof(best_iteration_), 1) != 1)
458  return false;
459  if (fp->FWrite(&worst_error_rate_, sizeof(worst_error_rate_), 1) != 1)
460  return false;
461  if (fp->FWrite(&worst_error_rates_, sizeof(worst_error_rates_), 1) != 1)
462  return false;
463  if (fp->FWrite(&worst_iteration_, sizeof(worst_iteration_), 1) != 1)
464  return false;
465  if (fp->FWrite(&stall_iteration_, sizeof(stall_iteration_), 1) != 1)
466  return false;
467  if (!best_model_data_.Serialize(fp)) return false;
468  if (!worst_model_data_.Serialize(fp)) return false;
469  if (amount != NO_BEST_TRAINER && !best_trainer_.Serialize(fp)) return false;
470  GenericVector<char> sub_data;
471  if (sub_trainer_ != NULL && !SaveTrainingDump(LIGHT, sub_trainer_, &sub_data))
472  return false;
473  if (!sub_data.Serialize(fp)) return false;
474  if (!best_error_history_.Serialize(fp)) return false;
475  if (!best_error_iterations_.Serialize(fp)) return false;
476  if (fp->FWrite(&improvement_steps_, sizeof(improvement_steps_), 1) != 1)
477  return false;
478  return true;
479 }
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
virtual bool Serialize(TFile *fp) const
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
SerializeAmount serialize_amount_
Definition: lstmtrainer.h:408
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:451
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:426
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:439
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:452
uint8_t uinT8
Definition: host.h:35
bool Serialize(FILE *fp) const
bool Serialize(TFile *fp) const

◆ set_perfect_delay()

void tesseract::LSTMTrainer::set_perfect_delay ( int  delay)
inline

Definition at line 151 of file lstmtrainer.h.

151 { perfect_delay_ = delay; }

◆ SetSerializeMode()

void tesseract::LSTMTrainer::SetSerializeMode ( SerializeAmount  serialize_amount) const
inline

Definition at line 178 of file lstmtrainer.h.

178  {
179  serialize_amount_ = serialize_amount;
180  }
SerializeAmount serialize_amount_
Definition: lstmtrainer.h:408

◆ SetUnicharsetProperties()

void tesseract::LSTMTrainer::SetUnicharsetProperties ( const STRING script_dir)
protected

Definition at line 982 of file lstmtrainer.cpp.

982  {
983  tprintf("Setting unichar properties\n");
984  for (int s = 0; s < GetUnicharset().get_script_table_size(); ++s) {
985  if (strcmp("NULL", GetUnicharset().get_script_from_script_id(s)) == 0)
986  continue;
987  // Load the unicharset for the script if available.
988  STRING filename = script_dir + "/" +
990  ".unicharset";
991  UNICHARSET script_set;
992  GenericVector<char> data;
993  if ((*file_reader_)(filename, &data) &&
994  script_set.load_from_inmemory_file(&data[0], data.size())) {
995  tprintf("Setting properties for script %s\n",
996  GetUnicharset().get_script_from_script_id(s));
998  }
999  }
1000  if (IsRecoding()) {
1001  STRING filename = script_dir + "/radical-stroke.txt";
1002  GenericVector<char> data;
1003  if ((*file_reader_)(filename, &data)) {
1004  data += '\0';
1005  STRING stroke_table = &data[0];
1007  &stroke_table)) {
1008  RecodedCharID code;
1010  null_char_ = code(0);
1011  // Space should encode as itself.
1013  ASSERT_HOST(code(0) == UNICHAR_SPACE);
1014  return;
1015  }
1016  } else {
1017  tprintf("Failed to load radical-stroke info from: %s\n",
1018  filename.string());
1019  }
1021  }
1022 }
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
const char * get_script_from_script_id(int id) const
Definition: unicharset.h:814
const UNICHARSET & GetUnicharset() const
#define tprintf(...)
Definition: tprintf.h:31
void SetPropertiesFromOther(const UNICHARSET &src)
Definition: unicharset.h:505
const char * string() const
Definition: strngs.cpp:198
int size() const
Definition: genericvector.h:72
#define ASSERT_HOST(x)
Definition: errcode.h:84
bool ComputeEncoding(const UNICHARSET &unicharset, int null_id, STRING *radical_stroke_table)
int get_script_table_size() const
Definition: unicharset.h:809
Definition: strngs.h:45
UNICHARSET unicharset
Definition: ccutil.h:68
const char * filename
Definition: ioapi.h:38
bool load_from_inmemory_file(const char *const memory, int mem_size, bool skip_fragments)
Definition: unicharset.cpp:724

◆ SetupCheckpointInfo()

void tesseract::LSTMTrainer::SetupCheckpointInfo ( )

◆ StartSubtrainer()

void tesseract::LSTMTrainer::StartSubtrainer ( STRING log_msg)

Definition at line 547 of file lstmtrainer.cpp.

547  {
548  delete sub_trainer_;
549  sub_trainer_ = new LSTMTrainer();
551  *log_msg += " Failed to revert to previous best for trial!";
552  delete sub_trainer_;
553  sub_trainer_ = NULL;
554  } else {
555  log_msg->add_str_int(" Trial sub_trainer_ from iteration ",
557  // Reduce learning rate so it doesn't diverge this time.
558  sub_trainer_->ReduceLearningRates(this, log_msg);
559  // If it fails again, we will wait twice as long before reverting again.
560  int stall_offset =
562  stall_iteration_ = learning_iteration() + 2 * stall_offset;
564  // Re-save the best trainer with the new learning rates and stall iteration.
566  }
567 }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
GenericVector< char > best_trainer_
Definition: lstmtrainer.h:441
CheckPointWriter checkpoint_writer_
Definition: lstmtrainer.h:419
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
virtual R Run(A1, A2)=0
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418
int learning_iteration() const
Definition: lstmtrainer.h:149
virtual R Run(A1, A2, A3)=0

◆ training_data()

const DocumentCache& tesseract::LSTMTrainer::training_data ( ) const
inline

Definition at line 165 of file lstmtrainer.h.

165  {
166  return training_data_;
167  }
DocumentCache training_data_
Definition: lstmtrainer.h:406

◆ TrainOnLine() [1/2]

const ImageData* tesseract::LSTMTrainer::TrainOnLine ( LSTMTrainer samples_trainer,
bool  batch 
)
inline

Definition at line 268 of file lstmtrainer.h.

268  {
269  int sample_index = sample_iteration();
270  const ImageData* image =
271  samples_trainer->training_data_.GetPageBySerial(sample_index);
272  if (image != NULL) {
273  Trainability trainable = TrainOnLine(image, batch);
274  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
275  return NULL; // Sample was unusable.
276  }
277  } else {
279  }
280  return image;
281  }
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:268

◆ TrainOnLine() [2/2]

Trainability tesseract::LSTMTrainer::TrainOnLine ( const ImageData trainingdata,
bool  batch 
)

Definition at line 794 of file lstmtrainer.cpp.

795  {
796  NetworkIO fwd_outputs, targets;
797  Trainability trainable =
798  PrepareForBackward(trainingdata, &fwd_outputs, &targets);
800  if (trainable == UNENCODABLE || trainable == NOT_BOXED) {
801  return trainable; // Sample was unusable.
802  }
803  bool debug = debug_interval_ > 0 &&
805  // Run backprop on the output.
806  NetworkIO bp_deltas;
807  if (network_->IsTraining() &&
808  (trainable != PERFECT ||
811  network_->Backward(debug, targets, &scratch_space_, &bp_deltas);
812  network_->Update(learning_rate_, batch ? -1.0f : momentum_,
813  training_iteration_ + 1);
814  }
815 #ifndef GRAPHICS_DISABLED
816  if (debug_interval_ == 1 && debug_win_ != NULL) {
818  }
819 #endif // GRAPHICS_DISABLED
820  // Roll the memory of past means.
822  return trainable;
823 }
virtual void Update(float learning_rate, float momentum, int num_samples)
Definition: network.h:218
NetworkScratch scratch_space_
bool IsTraining() const
Definition: network.h:115
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: network.h:259
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
SVEvent * AwaitEvent(SVEventType type)
Definition: scrollview.cpp:449

◆ TransitionTrainingStage()

bool tesseract::LSTMTrainer::TransitionTrainingStage ( float  error_threshold)

Definition at line 423 of file lstmtrainer.cpp.

423  {
424  if (best_error_rate_ < error_threshold &&
426  ++training_stage_;
427  return true;
428  }
429  return false;
430 }

◆ TryLoadingCheckpoint()

bool tesseract::LSTMTrainer::TryLoadingCheckpoint ( const char *  filename)

Definition at line 125 of file lstmtrainer.cpp.

125  {
126  GenericVector<char> data;
127  if (!(*file_reader_)(filename, &data)) return false;
128  tprintf("Loaded file %s, unpacking...\n", filename);
129  return checkpoint_reader_->Run(data, this);
130 }
virtual R Run(A1, A2)=0
#define tprintf(...)
Definition: tprintf.h:31
CheckPointReader checkpoint_reader_
Definition: lstmtrainer.h:418
const char * filename
Definition: ioapi.h:38

◆ UpdateErrorBuffer()

void tesseract::LSTMTrainer::UpdateErrorBuffer ( double  new_error,
ErrorTypes  type 
)
protected

Definition at line 1247 of file lstmtrainer.cpp.

1247  {
1249  error_buffers_[type][index] = new_error;
1250  // Compute the mean error.
1251  int mean_count = MIN(training_iteration_ + 1, error_buffers_[type].size());
1252  double buffer_sum = 0.0;
1253  for (int i = 0; i < mean_count; ++i) buffer_sum += error_buffers_[type][i];
1254  double mean = buffer_sum / mean_count;
1255  // Trim precision to 1/1000 of 1%.
1256  error_rates_[type] = IntCastRounded(100000.0 * mean) / 1000.0;
1257 }
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
GenericVector< double > error_buffers_[ET_COUNT]
Definition: lstmtrainer.h:473
voidpf void uLong size
Definition: ioapi.h:39
static const int kRollingBufferSize_
Definition: lstmtrainer.h:472
int IntCastRounded(double x)
Definition: helpers.h:179
#define MIN(x, y)
Definition: ndminx.h:28

◆ UpdateErrorGraph()

STRING tesseract::LSTMTrainer::UpdateErrorGraph ( int  iteration,
double  error_rate,
const GenericVector< char > &  model_data,
TestCallback  tester 
)
protected

Definition at line 1279 of file lstmtrainer.cpp.

1281  {
1282  if (error_rate > best_error_rate_
1283  && iteration < best_iteration_ + kErrorGraphInterval) {
1284  // Too soon to record a new point.
1285  if (tester != NULL)
1286  return tester->Run(worst_iteration_, NULL, worst_model_data_,
1288  else
1289  return "";
1290  }
1291  STRING result;
1292  // NOTE: there are 2 asymmetries here:
1293  // 1. We are computing the global minimum, but the local maximum in between.
1294  // 2. If the tester returns an empty string, indicating that it is busy,
1295  // call it repeatedly on new local maxima to test the previous min, but
1296  // not the other way around, as there is little point testing the maxima
1297  // between very frequent minima.
1298  if (error_rate < best_error_rate_) {
1299  // This is a new (global) minimum.
1300  if (tester != NULL) {
1301  result = tester->Run(worst_iteration_, worst_error_rates_,
1304  best_model_data_ = model_data;
1305  }
1306  best_error_rate_ = error_rate;
1307  memcpy(best_error_rates_, error_rates_, sizeof(error_rates_));
1308  best_iteration_ = iteration;
1309  best_error_history_.push_back(error_rate);
1310  best_error_iterations_.push_back(iteration);
1311  // Compute 2% decay time.
1312  double two_percent_more = error_rate + 2.0;
1313  int i;
1314  for (i = best_error_history_.size() - 1;
1315  i >= 0 && best_error_history_[i] < two_percent_more; --i) {
1316  }
1317  int old_iteration = i >= 0 ? best_error_iterations_[i] : 0;
1318  improvement_steps_ = iteration - old_iteration;
1319  tprintf("2 Percent improvement time=%d, best error was %g @ %d\n",
1320  improvement_steps_, i >= 0 ? best_error_history_[i] : 100.0,
1321  old_iteration);
1322  } else if (error_rate > best_error_rate_) {
1323  // This is a new (local) maximum.
1324  if (tester != NULL) {
1325  if (best_model_data_.empty()) {
1326  // Allow for multiple data points with "worst" error rate.
1327  result = tester->Run(worst_iteration_, worst_error_rates_,
1329  } else {
1330  result = tester->Run(best_iteration_, best_error_rates_,
1332  }
1333  if (result.length() > 0)
1335  worst_model_data_ = model_data;
1336  }
1337  }
1338  worst_error_rate_ = error_rate;
1339  memcpy(worst_error_rates_, error_rates_, sizeof(error_rates_));
1340  worst_iteration_ = iteration;
1341  return result;
1342 }
int CurrentTrainingStage() const
Definition: lstmtrainer.h:213
double worst_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:432
double error_rates_[ET_COUNT]
Definition: lstmtrainer.h:475
GenericVector< double > best_error_history_
Definition: lstmtrainer.h:451
int push_back(T object)
#define tprintf(...)
Definition: tprintf.h:31
bool empty() const
Definition: genericvector.h:90
void truncate(int size)
inT32 length() const
Definition: strngs.cpp:193
int size() const
Definition: genericvector.h:72
Definition: strngs.h:45
GenericVector< char > best_model_data_
Definition: lstmtrainer.h:438
double best_error_rates_[ET_COUNT]
Definition: lstmtrainer.h:426
GenericVector< char > worst_model_data_
Definition: lstmtrainer.h:439
GenericVector< int > best_error_iterations_
Definition: lstmtrainer.h:452
const int kErrorGraphInterval
Definition: lstmtrainer.cpp:56

◆ UpdateSubtrainer()

SubTrainerResult tesseract::LSTMTrainer::UpdateSubtrainer ( STRING log_msg)

Definition at line 577 of file lstmtrainer.cpp.

577  {
578  double training_error = CharError();
579  double sub_error = sub_trainer_->CharError();
580  double sub_margin = (training_error - sub_error) / sub_error;
581  if (sub_margin >= kSubTrainerMarginFraction) {
582  log_msg->add_str_double(" sub_trainer=", sub_error);
583  log_msg->add_str_double(" margin=", 100.0 * sub_margin);
584  *log_msg += "\n";
585  // Catch up to current iteration.
586  int end_iteration = training_iteration();
587  while (sub_trainer_->training_iteration() < end_iteration &&
588  sub_margin >= kSubTrainerMarginFraction) {
589  int target_iteration =
591  while (sub_trainer_->training_iteration() < target_iteration) {
592  sub_trainer_->TrainOnLine(this, false);
593  }
594  STRING batch_log = "Sub:";
595  sub_trainer_->PrepareLogMsg(&batch_log);
596  batch_log += "\n";
597  tprintf("UpdateSubtrainer:%s", batch_log.string());
598  *log_msg += batch_log;
599  sub_error = sub_trainer_->CharError();
600  sub_margin = (training_error - sub_error) / sub_error;
601  }
602  if (sub_error < best_error_rate_ &&
603  sub_margin >= kSubTrainerMarginFraction) {
604  // The sub_trainer_ has won the race to a new best. Switch to it.
605  GenericVector<char> updated_trainer;
606  SaveTrainingDump(LIGHT, sub_trainer_, &updated_trainer);
607  ReadTrainingDump(updated_trainer, this);
608  log_msg->add_str_int(" Sub trainer wins at iteration ",
610  *log_msg += "\n";
611  return STR_REPLACED;
612  }
613  return STR_UPDATED;
614  }
615  return STR_NONE;
616 }
void PrepareLogMsg(STRING *log_msg) const
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
const double kSubTrainerMarginFraction
Definition: lstmtrainer.cpp:50
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
Definition: lstmtrainer.h:444
Definition: strngs.h:45
void add_str_double(const char *str, double number)
Definition: strngs.cpp:391
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
Definition: lstmtrainer.h:268
const int kNumPagesPerBatch
Definition: lstmtrainer.cpp:58
double CharError() const
Definition: lstmtrainer.h:139

Member Data Documentation

◆ align_win_

ScrollView* tesseract::LSTMTrainer::align_win_
protected

Definition at line 390 of file lstmtrainer.h.

◆ best_error_history_

GenericVector<double> tesseract::LSTMTrainer::best_error_history_
protected

Definition at line 451 of file lstmtrainer.h.

◆ best_error_iterations_

GenericVector<int> tesseract::LSTMTrainer::best_error_iterations_
protected

Definition at line 452 of file lstmtrainer.h.

◆ best_error_rate_

double tesseract::LSTMTrainer::best_error_rate_
protected

Definition at line 424 of file lstmtrainer.h.

◆ best_error_rates_

double tesseract::LSTMTrainer::best_error_rates_[ET_COUNT]
protected

Definition at line 426 of file lstmtrainer.h.

◆ best_iteration_

int tesseract::LSTMTrainer::best_iteration_
protected

Definition at line 428 of file lstmtrainer.h.

◆ best_model_data_

GenericVector<char> tesseract::LSTMTrainer::best_model_data_
protected

Definition at line 438 of file lstmtrainer.h.

◆ best_model_name_

STRING tesseract::LSTMTrainer::best_model_name_
protected

Definition at line 410 of file lstmtrainer.h.

◆ best_trainer_

GenericVector<char> tesseract::LSTMTrainer::best_trainer_
protected

Definition at line 441 of file lstmtrainer.h.

◆ checkpoint_iteration_

int tesseract::LSTMTrainer::checkpoint_iteration_
protected

Definition at line 400 of file lstmtrainer.h.

◆ checkpoint_name_

STRING tesseract::LSTMTrainer::checkpoint_name_
protected

Definition at line 404 of file lstmtrainer.h.

◆ checkpoint_reader_

CheckPointReader tesseract::LSTMTrainer::checkpoint_reader_
protected

Definition at line 418 of file lstmtrainer.h.

◆ checkpoint_writer_

CheckPointWriter tesseract::LSTMTrainer::checkpoint_writer_
protected

Definition at line 419 of file lstmtrainer.h.

◆ ctc_win_

ScrollView* tesseract::LSTMTrainer::ctc_win_
protected

Definition at line 394 of file lstmtrainer.h.

◆ debug_interval_

int tesseract::LSTMTrainer::debug_interval_
protected

Definition at line 398 of file lstmtrainer.h.

◆ error_buffers_

GenericVector<double> tesseract::LSTMTrainer::error_buffers_[ET_COUNT]
protected

Definition at line 473 of file lstmtrainer.h.

◆ error_rate_of_last_saved_best_

float tesseract::LSTMTrainer::error_rate_of_last_saved_best_
protected

Definition at line 446 of file lstmtrainer.h.

◆ error_rates_

double tesseract::LSTMTrainer::error_rates_[ET_COUNT]
protected

Definition at line 475 of file lstmtrainer.h.

◆ file_reader_

FileReader tesseract::LSTMTrainer::file_reader_
protected

Definition at line 414 of file lstmtrainer.h.

◆ file_writer_

FileWriter tesseract::LSTMTrainer::file_writer_
protected

Definition at line 415 of file lstmtrainer.h.

◆ improvement_steps_

int tesseract::LSTMTrainer::improvement_steps_
protected

Definition at line 454 of file lstmtrainer.h.

◆ kRollingBufferSize_

const int tesseract::LSTMTrainer::kRollingBufferSize_ = 1000
staticprotected

Definition at line 472 of file lstmtrainer.h.

◆ last_perfect_training_iteration_

int tesseract::LSTMTrainer::last_perfect_training_iteration_
protected

Definition at line 469 of file lstmtrainer.h.

◆ learning_iteration_

int tesseract::LSTMTrainer::learning_iteration_
protected

Definition at line 458 of file lstmtrainer.h.

◆ model_base_

STRING tesseract::LSTMTrainer::model_base_
protected

Definition at line 402 of file lstmtrainer.h.

◆ num_training_stages_

int tesseract::LSTMTrainer::num_training_stages_
protected

Definition at line 412 of file lstmtrainer.h.

◆ perfect_delay_

int tesseract::LSTMTrainer::perfect_delay_
protected

Definition at line 466 of file lstmtrainer.h.

◆ prev_sample_iteration_

int tesseract::LSTMTrainer::prev_sample_iteration_
protected

Definition at line 460 of file lstmtrainer.h.

◆ recon_win_

ScrollView* tesseract::LSTMTrainer::recon_win_
protected

Definition at line 396 of file lstmtrainer.h.

◆ serialize_amount_

SerializeAmount tesseract::LSTMTrainer::serialize_amount_
mutableprotected

Definition at line 408 of file lstmtrainer.h.

◆ stall_iteration_

int tesseract::LSTMTrainer::stall_iteration_
protected

Definition at line 436 of file lstmtrainer.h.

◆ sub_trainer_

LSTMTrainer* tesseract::LSTMTrainer::sub_trainer_
protected

Definition at line 444 of file lstmtrainer.h.

◆ target_win_

ScrollView* tesseract::LSTMTrainer::target_win_
protected

Definition at line 392 of file lstmtrainer.h.

◆ training_data_

DocumentCache tesseract::LSTMTrainer::training_data_
protected

Definition at line 406 of file lstmtrainer.h.

◆ training_stage_

int tesseract::LSTMTrainer::training_stage_
protected

Definition at line 448 of file lstmtrainer.h.

◆ worst_error_rate_

double tesseract::LSTMTrainer::worst_error_rate_
protected

Definition at line 430 of file lstmtrainer.h.

◆ worst_error_rates_

double tesseract::LSTMTrainer::worst_error_rates_[ET_COUNT]
protected

Definition at line 432 of file lstmtrainer.h.

◆ worst_iteration_

int tesseract::LSTMTrainer::worst_iteration_
protected

Definition at line 434 of file lstmtrainer.h.

◆ worst_model_data_

GenericVector<char> tesseract::LSTMTrainer::worst_model_data_
protected

Definition at line 439 of file lstmtrainer.h.


The documentation for this class was generated from the following files: