21 #include "config_auto.h" 27 #include "allheaders.h" 35 #ifdef INCLUDE_TENSORFLOW 89 const char* model_base,
const char* checkpoint_name,
90 int debug_interval,
inT64 max_memory)
128 tprintf(
"Loaded file %s, unpacking...\n", filename);
139 const STRING& script_dir,
int train_flags) {
176 int net_flags,
float weight_range,
193 tprintf(
"Built network:%s from request %s\n",
195 tprintf(
"Training parameters:\n Debug interval = %d," 196 " weights = %g, learning rate = %g, momentum=%g\n",
204 #ifdef INCLUDE_TENSORFLOW 206 TFNetwork* tf_net =
new TFNetwork(
"TensorFlow");
209 tprintf(
"InitFromProtoStr failed!!\n");
216 tprintf(
"TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
236 for (
int i = 0; i <
ET_COUNT; ++i) {
249 const ImageData* trainingdata,
int iteration,
double min_dict_ratio,
250 double dict_ratio_step,
double max_dict_ratio,
double min_cert_offset,
251 double cert_offset_step,
double max_cert_offset,
STRING* results) {
273 for (
double r = min_dict_ratio; r < max_dict_ratio; r += dict_ratio_step) {
274 for (
double c = min_cert_offset; c < max_cert_offset;
275 c += cert_offset_step) {
282 if ((r == min_dict_ratio && c == min_cert_offset) ||
283 !std::isfinite(word_error)) {
286 tprintf(
"r=%g, c=%g, truth=%s, ocr=%s, wderr=%g, truth[0]=%d\n", r, c,
340 *log_msg +=
UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
353 *log_msg +=
" failed to write best model:";
355 *log_msg +=
" wrote best model:";
358 *log_msg += best_model_name;
363 *log_msg +=
UpdateErrorGraph(iteration, error_rate, rec_model_data, tester);
367 *log_msg +=
"\nDivergence! ";
384 result = sub_trainer_result !=
STR_NONE;
392 *log_msg +=
" failed to write checkpoint.";
394 *log_msg +=
" wrote checkpoint.";
415 *log_msg += intro_str;
444 for (
int i = 0; i <
ET_COUNT; ++i) {
451 if (fp->
FWrite(&amount,
sizeof(amount), 1) != 1)
return false;
452 if (amount ==
LIGHT)
return true;
473 if (!sub_data.
Serialize(fp))
return false;
489 tprintf(
"Warning: LSTMTrainer deserialized an LSTMRecognizer!\n");
502 for (
int i = 0; i <
ET_COUNT; ++i) {
509 if (fp->
FRead(&amount,
sizeof(amount), 1) != 1)
return false;
510 if (amount ==
LIGHT)
return true;
531 if (sub_data.
empty()) {
551 *log_msg +=
" Failed to revert to previous best for trial!";
555 log_msg->
add_str_int(
" Trial sub_trainer_ from iteration ",
580 double sub_margin = (training_error - sub_error) / sub_error;
581 if (sub_margin >= kSubTrainerMarginFraction) {
589 int target_iteration =
594 STRING batch_log =
"Sub:";
598 *log_msg += batch_log;
600 sub_margin = (training_error - sub_error) / sub_error;
603 sub_margin >= kSubTrainerMarginFraction) {
608 log_msg->
add_str_int(
" Sub trainer wins at iteration ",
624 kLearningRateDecay, kNumAdjustmentIterations, samples_trainer);
625 log_msg->
add_str_int(
"\nReduced learning rate on layers: ", num_reduced);
647 const double kEpsilon = 1.0e-30;
649 int num_layers = layers.
size();
654 for (
int i = 0; i < LR_COUNT; ++i) {
658 double momentum_factor = 1.0 / (1.0 -
momentum_);
661 for (
int i = 0; i < num_layers; ++i) {
666 for (
int s = 0; s < num_samples; ++s) {
668 for (
int ww = 0; ww < LR_COUNT; ++ww) {
670 float ww_factor = momentum_factor;
671 if (ww == LR_DOWN) ww_factor *= factor;
678 for (
int i = 0; i < num_layers; ++i) {
679 if (num_weights[i] == 0)
continue;
687 if (trainingdata == NULL)
continue;
691 for (
int i = 0; i < num_layers; ++i) {
692 if (num_weights[i] == 0)
continue;
698 layer->
Update(0.0, kEpsilon, 0);
702 float before_bad = bad_sums[ww][i];
703 float before_ok = ok_sums[ww][i];
705 &ok_sums[ww][i], &bad_sums[ww][i]);
707 bad_sums[ww][i] + ok_sums[ww][i] - before_bad - before_ok;
709 bad_frac = (bad_sums[ww][i] - before_bad) / bad_frac;
715 for (
int i = 0; i < num_layers; ++i) {
716 if (num_weights[i] == 0)
continue;
719 double total_down = bad_sums[LR_DOWN][i] + ok_sums[LR_DOWN][i];
720 double total_same = bad_sums[LR_SAME][i] + ok_sums[LR_SAME][i];
721 double frac_down = bad_sums[LR_DOWN][i] / total_down;
722 double frac_same = bad_sums[LR_SAME][i] / total_same;
724 lr * factor, 100.0 * frac_down, lr, 100.0 * frac_same);
725 if (frac_down < frac_same * kImprovementFraction) {
733 if (num_lowered == 0) {
735 for (
int i = 0; i < num_layers; ++i) {
736 if (num_weights[i] > 0) {
752 tprintf(
"Empty truth string!\n");
758 if (!simple_text) labels->
push_back(null_char);
762 for (
int i = 0; i < internal_labels.
size(); ++i) {
763 if (recoder != NULL) {
768 for (
int j = 0; j < len; ++j) {
770 if (!simple_text) labels->
push_back(null_char);
779 if (!simple_text) labels->
push_back(null_char);
782 if (success)
return true;
784 tprintf(
"Encoding of string failed! Failure bytes:");
785 while (err_index < str.
length()) {
786 tprintf(
" %x", str[err_index++]);
815 #ifndef GRAPHICS_DISABLED 819 #endif // GRAPHICS_DISABLED 830 if (trainingdata == NULL) {
831 tprintf(
"Null trainingdata.\n");
839 tprintf(
"Can't encode transcription: %s\n",
844 while (w < truth_labels.
size() &&
847 if (w == truth_labels.
size()) {
848 tprintf(
"Blank transcription: %s\n",
855 if (!
RecognizeLine(*trainingdata, invert, debug, invert, 0.0f, &image_scale,
856 &inputs, fwd_outputs)) {
857 tprintf(
"Image not trainable\n");
864 tprintf(
"Compute simple targets failed!\n");
867 }
else if (loss_type ==
LT_CTC) {
869 tprintf(
"Compute CTC targets failed!\n");
873 tprintf(
"Logistic outputs not implemented yet!\n");
880 if (loss_type !=
LT_CTC) {
900 trainingdata->
page_number(), delta_error == 0.0 ?
"(Perfect)" :
"");
902 if (delta_error == 0.0)
return PERFECT;
983 tprintf(
"Setting unichar properties\n");
985 if (strcmp(
"NULL",
GetUnicharset().get_script_from_script_id(s)) == 0)
995 tprintf(
"Setting properties for script %s\n",
1005 STRING stroke_table = &data[0];
1017 tprintf(
"Failed to load radical-stroke info from: %s\n",
1034 if (truth_text.
string() == NULL || truth_text.
length() <= 0) {
1035 tprintf(
"Empty truth string at decode time!\n");
1044 tprintf(
"Iteration %d: ALIGNED TRUTH : %s\n",
1047 tprintf(
"TRAINING activation path for truth string %s\n",
1062 const char* window_name,
ScrollView** window) {
1063 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics. 1064 int width = targets.
Width();
1068 for (
int c = 0; c < num_features; ++c) {
1070 (*window)->Pen(static_cast<ScrollView::Color>(color));
1072 for (
int t = 0; t < width; ++t) {
1073 double target = targets.
f(t)[c];
1077 (*window)->SetCursor(t - 1, 0);
1080 (*window)->DrawTo(t, target);
1081 }
else if (start_t >= 0) {
1082 (*window)->DrawTo(t, 0);
1083 (*window)->DrawTo(start_t - 1, 0);
1088 (*window)->DrawTo(width, 0);
1089 (*window)->DrawTo(start_t - 1, 0);
1092 (*window)->Update();
1093 #endif // GRAPHICS_DISABLED 1101 if (truth_labels.
size() > targets->
Width()) {
1102 tprintf(
"Error: transcription %s too long to fit into target of width %d\n",
1106 for (
int i = 0; i < truth_labels.
size() && i < targets->
Width(); ++i) {
1109 for (
int i = truth_labels.
size(); i < targets->
Width(); ++i) {
1130 double char_error,
double word_error) {
1150 double total_error = 0.0;
1151 int width = deltas.
Width();
1153 for (
int t = 0; t < width; ++t) {
1154 const float* class_errs = deltas.
f(t);
1155 for (
int c = 0; c < num_classes; ++c) {
1156 double error = class_errs[c];
1157 total_error += error * error;
1160 return sqrt(total_error / (width * num_classes));
1170 int width = deltas.
Width();
1172 for (
int t = 0; t < width; ++t) {
1173 const float* class_errs = deltas.
f(t);
1174 for (
int c = 0; c < num_classes; ++c) {
1175 float abs_delta = fabs(class_errs[c]);
1178 if (0.5 <= abs_delta)
1182 return static_cast<double>(num_errors) / width;
1191 for (
int i = 0; i < truth_str.
size(); ++i) {
1193 ++label_counts[truth_str[i]];
1197 for (
int i = 0; i < ocr_str.
size(); ++i) {
1199 --label_counts[ocr_str[i]];
1202 int char_errors = 0;
1203 for (
int i = 0; i < label_counts.
size(); ++i) {
1204 char_errors += abs(label_counts[i]);
1206 if (truth_size == 0) {
1207 return (char_errors == 0) ? 0.0 : 1.0;
1209 return static_cast<double>(char_errors) / truth_size;
1215 typedef std::unordered_map<std::string, int, std::hash<std::string> > StrMap;
1217 truth_str->
split(
' ', &truth_words);
1218 if (truth_words.
empty())
return 0.0;
1219 ocr_str->
split(
' ', &ocr_words);
1221 for (
int i = 0; i < truth_words.
size(); ++i) {
1222 std::string truth_word(truth_words[i].
string());
1223 StrMap::iterator it = word_counts.find(truth_word);
1224 if (it == word_counts.end())
1225 word_counts.insert(std::make_pair(truth_word, 1));
1229 for (
int i = 0; i < ocr_words.
size(); ++i) {
1230 std::string ocr_word(ocr_words[i].
string());
1231 StrMap::iterator it = word_counts.find(ocr_word);
1232 if (it == word_counts.end())
1233 word_counts.insert(std::make_pair(ocr_word, -1));
1237 int word_recall_errs = 0;
1238 for (StrMap::const_iterator it = word_counts.begin(); it != word_counts.end();
1240 if (it->second > 0) word_recall_errs += it->second;
1242 return static_cast<double>(word_recall_errs) / truth_words.
size();
1252 double buffer_sum = 0.0;
1253 for (
int i = 0; i < mean_count; ++i) buffer_sum +=
error_buffers_[type][i];
1254 double mean = buffer_sum / mean_count;
1268 tprintf(
"Mean rms=%g%%, delta=%g%%, train=%g%%(%g%%), skip ratio=%g%%\n",
1300 if (tester != NULL) {
1312 double two_percent_more = error_rate + 2.0;
1319 tprintf(
"2 Percent improvement time=%d, best error was %g @ %d\n",
1324 if (tester != NULL) {
bool DeSerialize(bool swap, FILE *fp)
void PrepareLogMsg(STRING *log_msg) const
bool ComputeCTCTargets(const GenericVector< int > &truth_labels, NetworkIO *outputs, NetworkIO *targets)
void add_str_int(const char *str, int number)
float GetLayerLearningRate(const STRING &id) const
int CurrentTrainingStage() const
GenericVector< char > best_trainer_
int EncodeUnichar(int unichar_id, RecodedCharID *code) const
CheckPointWriter checkpoint_writer_
bool AnySuspiciousTruth(float confidence_thr) const
double worst_error_rates_[ET_COUNT]
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
double error_rates_[ET_COUNT]
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
void RecognizeLine(const ImageData &image_data, bool invert, bool debug, double worst_dict_cert, bool use_alternates, const UNICHARSET *target_unicharset, const TBOX &line_box, float score_ratio, bool one_word, PointerVector< WERD_RES > *words)
virtual void Update(float learning_rate, float momentum, int num_samples)
void init_to_size(int size, T t)
void ReduceLearningRates(LSTMTrainer *samples_trainer, STRING *log_msg)
void InitCharSet(const UNICHARSET &unicharset, const STRING &script_dir, int train_flags)
virtual bool Serialize(TFile *fp) const
static const float kMinCertainty
LossType OutputLossType() const
double ComputeWinnerError(const NetworkIO &deltas)
GenericVector< double > error_buffers_[ET_COUNT]
bool(* FileReader)(const STRING &filename, GenericVector< char > *data)
SerializeAmount serialize_amount_
GenericVector< double > best_error_history_
const int kMinStallIterations
const double kMinDivergenceRate
void SubtractAllFromFloat(const NetworkIO &src)
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum)
void SetActivations(int t, int label, float ok_score)
const char * get_script_from_script_id(int id) const
const double kHighConfidence
const int kNumAdjustmentIterations
const UNICHARSET & GetUnicharset() const
void SetPropertiesFromOther(const UNICHARSET &src)
NetworkScratch scratch_space_
static const int kRollingBufferSize_
const char * string() const
void Resize(const NetworkIO &src, int num_features)
bool DeSerialize(TFile *fp)
void ScaleLearningRate(double factor)
int FReadEndian(void *buffer, int size, int count)
void LogIterations(const char *intro_str, STRING *log_msg) const
bool TryLoadingCheckpoint(const char *filename)
GenericVector< STRING > EnumerateLayers() const
void ExtractBestPathAsLabels(GenericVector< int > *labels, GenericVector< int > *xcoords) const
const double kImprovementFraction
int IntCastRounded(double x)
const double kSubTrainerMarginFraction
void DisplayTargets(const NetworkIO &targets, const char *window_name, ScrollView **window)
DocumentCache training_data_
const STRING & transcription() const
void OpenWrite(GenericVector< char > *data)
const double kStageTransitionThreshold
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
int last_perfect_training_iteration_
LIST search(LIST list, void *key, int_compare is_equal)
const double kLearningRateDecay
bool ComputeEncoding(const UNICHARSET &unicharset, int null_id, STRING *radical_stroke_table)
bool SaveTrainingDump(SerializeAmount serialize_amount, const LSTMTrainer *trainer, GenericVector< char > *data) const
CachingStrategy CacheStrategy() const
SubTrainerResult UpdateSubtrainer(STRING *log_msg)
int sample_iteration() const
bool DebugLSTMTraining(const NetworkIO &inputs, const ImageData &trainingdata, const NetworkIO &fwd_outputs, const GenericVector< int > &truth_labels, const NetworkIO &outputs)
bool ReadTrainingDump(const GenericVector< char > &data, LSTMTrainer *trainer)
LSTMTrainer * sub_trainer_
int get_script_table_size() const
float error_rate_of_last_saved_best_
double ComputeErrorRates(const NetworkIO &deltas, double char_error, double word_error)
virtual void DebugWeights()
void StartSubtrainer(STRING *log_msg)
bool has_special_codes() const
double NewSingleError(ErrorTypes type) const
void add_str_double(const char *str, double number)
bool LoadDataFromFile(const char *filename, GenericVector< char > *data)
static void NormalizeProbs(NetworkIO *probs)
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
static bool InitNetwork(int num_outputs, STRING network_spec, int append_index, int net_flags, float weight_range, TRand *randomizer, Network **network)
bool TestFlag(NetworkFlags flag) const
void LabelsFromOutputs(const NetworkIO &outputs, float null_thr, GenericVector< int > *labels, GenericVector< int > *xcoords)
Network * GetLayer(const STRING &id) const
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
CheckPointReader checkpoint_reader_
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
int checkpoint_iteration_
int FWrite(const void *buffer, int size, int count)
bool LoadDocuments(const GenericVector< STRING > &filenames, CachingStrategy cache_strategy, FileReader reader)
int learning_iteration() const
int InitTensorFlowNetwork(const std::string &tf_proto)
void SetUnicharsetProperties(const STRING &script_dir)
STRING DumpFilename() const
void FillErrorBuffer(double new_error, ErrorTypes type)
double learning_rate() const
virtual bool DeSerialize(TFile *fp)
GenericVector< char > best_model_data_
double ComputeRMSError(const NetworkIO &deltas)
bool LoadAllTrainingData(const GenericVector< STRING > &filenames)
virtual R Run(A1, A2, A3)=0
bool ReadSizedTrainingDump(const char *data, int size)
const GenericVector< TBOX > & boxes() const
const STRING & imagefilename() const
void Decode(const NetworkIO &output, double dict_ratio, double cert_offset, double worst_dict_cert, const UNICHARSET *charset)
static LSTMRecognizer * ReadRecognitionDump(const GenericVector< char > &data)
const int kNumPagesPerBatch
bool encode_string(const char *str, bool give_up_on_failure, GenericVector< UNICHAR_ID > *encoding, GenericVector< char > *lengths, int *encoded_length) const
void DisplayForward(const NetworkIO &inputs, const GenericVector< int > &labels, const GenericVector< int > &label_coords, const char *window_name, ScrollView **window)
double ComputeCharError(const GenericVector< int > &truth_str, const GenericVector< int > &ocr_str)
Trainability PrepareForBackward(const ImageData *trainingdata, NetworkIO *fwd_outputs, NetworkIO *targets)
void DebugActivationPath(const NetworkIO &outputs, const GenericVector< int > &labels, const GenericVector< int > &xcoords)
bool(* FileWriter)(const GenericVector< char > &data, const STRING &filename)
double best_error_rates_[ET_COUNT]
STRING UpdateErrorGraph(int iteration, double error_rate, const GenericVector< char > &model_data, TestCallback tester)
GenericVector< char > worst_model_data_
GenericVector< int > best_error_iterations_
void SetIteration(int iteration)
const GENERIC_2D_ARRAY< float > & float_array() const
bool Open(const STRING &filename, FileReader reader)
double ComputeWordError(STRING *truth_str, STRING *ocr_str)
bool Serialize(FILE *fp) const
const STRING & name() const
int prev_sample_iteration_
const int kMinStartedErrorRate
bool ComputeTextTargets(const NetworkIO &outputs, const GenericVector< int > &truth_labels, NetworkIO *targets)
virtual R Run(A1, A2, A3, A4)=0
void UpdateErrorBuffer(double new_error, ErrorTypes type)
const int kErrorGraphInterval
virtual STRING spec() const
bool TransitionTrainingStage(float error_threshold)
STRING DecodeLabels(const GenericVector< int > &labels)
bool EncodeString(const STRING &str, GenericVector< int > *labels) const
virtual void CountAlternators(const Network &other, double *same, double *changed) const
int ReduceLayerLearningRates(double factor, int num_samples, LSTMTrainer *samples_trainer)
int training_iteration() const
static bool ComputeCTCTargets(const GenericVector< int > &truth_labels, int null_char, const GENERIC_2D_ARRAY< float > &outputs, NetworkIO *targets)
inT32 training_iteration_
bool load_from_inmemory_file(const char *const memory, int mem_size, bool skip_fragments)
bool SimpleTextOutput() const
bool Serialize(TFile *fp) const
const double kBestCheckpointFraction
virtual void SetEnableTraining(TrainingState state)
void split(const char c, GenericVector< STRING > *splited)
void CopyFrom(const UNICHARSET &src)
SVEvent * AwaitEvent(SVEventType type)
int FRead(void *buffer, int size, int count)
void SaveRecognitionDump(GenericVector< char > *data) const
void ScaleLayerLearningRate(const STRING &id, double factor)
Trainability GridSearchDictParams(const ImageData *trainingdata, int iteration, double min_dict_ratio, double dict_ratio_step, double max_dict_ratio, double min_cert_offset, double cert_offset_step, double max_cert_offset, STRING *results)