19 #ifndef TESSERACT_LSTM_LSTMTRAINER_H_ 20 #define TESSERACT_LSTM_LSTMTRAINER_H_ 94 CheckPointReader checkpoint_reader,
95 CheckPointWriter checkpoint_writer,
96 const char* model_base,
const char* checkpoint_name,
97 int debug_interval,
inT64 max_memory);
125 bool InitNetwork(
const STRING& network_spec,
int append_index,
int net_flags,
174 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
175 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
176 double cert_offset_step,
double max_cert_offset, STRING* results);
201 TestCallback tester, STRING* log_msg);
206 void LogIterations(
const char* intro_str, STRING* log_msg)
const;
243 LSTMTrainer* samples_trainer);
292 const LSTMTrainer* trainer,
386 TestCallback tester);
480 #endif // TESSERACT_LSTM_LSTMTRAINER_H_
void PrepareLogMsg(STRING *log_msg) const
double best_error_rate() const
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
const ImageData * GetPageBySerial(int serial)
int CurrentTrainingStage() const
GenericVector< char > best_trainer_
CheckPointWriter checkpoint_writer_
double worst_error_rates_[ET_COUNT]
double error_rates_[ET_COUNT]
void SetupCheckpointInfo()
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
void InitCharSet(const UNICHARSET &unicharset, const STRING &script_dir, int train_flags)
virtual bool Serialize(TFile *fp) const
double ComputeWinnerError(const NetworkIO &deltas)
GenericVector< double > error_buffers_[ET_COUNT]
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
SerializeAmount serialize_amount_
GenericVector< double > best_error_history_
const double * error_rates() const
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum)
DocumentCache * mutable_training_data()
const UNICHARSET & GetUnicharset() const
int best_iteration() const
static const int kRollingBufferSize_
void LogIterations(const char *intro_str, STRING *log_msg) const
bool TryLoadingCheckpoint(const char *filename)
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
DocumentCache training_data_
double ActivationError() const
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
int last_perfect_training_iteration_
void set_perfect_delay(int delay)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
int sample_iteration() const
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
float error_rate_of_last_saved_best_
virtual void ConvertToInt()
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
void StartSubtrainer(STRING *log_msg)
double NewSingleError(ErrorTypes type) const
const DocumentCache & training_data() const
TessResultCallback3< bool, SerializeAmount, const LSTMTrainer *, GenericVector< char > * > * CheckPointWriter
void SetSerializeMode(SerializeAmount serialize_amount) const
bool SaveBestModel(FileWriter writer) const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
CheckPointReader checkpoint_reader_
int checkpoint_iteration_
int learning_iteration() const
int InitTensorFlowNetwork(const std::string &tf_proto)
void SetUnicharsetProperties(const STRING &script_dir)
STRING DumpFilename() const
void FillErrorBuffer(double new_error, ErrorTypes type)
double learning_rate() const
virtual bool DeSerialize(TFile *fp)
GenericVector< char > best_model_data_
double ComputeRMSError(const NetworkIO &deltas)
const GenericVector< char > & best_trainer() const
bool LoadAllTrainingData(const GenericVector< STRING > &filenames)
bool ReadSizedTrainingDump(const char *data, int size)
static LSTMRecognizer * ReadRecognitionDump(const GenericVector< char > &data)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
double LastSingleError(ErrorTypes type) const
double best_error_rates_[ET_COUNT]
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
GenericVector< char > worst_model_data_
TessResultCallback2< bool, const GenericVector< char > &, LSTMTrainer * > * CheckPointReader
typedef int(ZCALLBACK *close_file_func) OF((voidpf opaque
GenericVector< int > best_error_iterations_
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
int prev_sample_iteration_
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
void UpdateErrorBuffer(double new_error, ErrorTypes type)
bool TransitionTrainingStage(float error_threshold)
int improvement_steps() const
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
int training_iteration() const
bool MaintainCheckpointsSpecific(int iteration, const GenericVector< char > *train_model, const GenericVector< char > *rec_model, TestCallback tester, STRING *log_msg)
bool SimpleTextOutput() const
TessResultCallback4< STRING, int, const double *, const GenericVector< char > &, int > * TestCallback
void SaveRecognitionDump(GenericVector< char > *data) const
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)