tesseract  4.00.00dev
lstmtraining.cpp File Reference
#include "base/commandlineflags.h"
#include "commontraining.h"
#include "lstmtester.h"
#include "lstmtrainer.h"
#include "params.h"
#include "strngs.h"
#include "tprintf.h"
#include "unicharset_training_utils.h"

Go to the source code of this file.

Functions

 INT_PARAM_FLAG (debug_interval, 0, "How often to display the alignment.")
 
 STRING_PARAM_FLAG (net_spec, "", "Network specification")
 
 INT_PARAM_FLAG (train_mode, 80, "Controls gross training behavior.")
 
 INT_PARAM_FLAG (net_mode, 192, "Controls network behavior.")
 
 INT_PARAM_FLAG (perfect_sample_delay, 4, "How many imperfect samples between perfect ones.")
 
 DOUBLE_PARAM_FLAG (target_error_rate, 0.01, "Final error rate in percent.")
 
 DOUBLE_PARAM_FLAG (weight_range, 0.1, "Range of initial random weights.")
 
 DOUBLE_PARAM_FLAG (learning_rate, 1.0e-4, "Weight factor for new deltas.")
 
 DOUBLE_PARAM_FLAG (momentum, 0.9, "Decay factor for repeating deltas.")
 
 INT_PARAM_FLAG (max_image_MB, 6000, "Max memory to use for images.")
 
 STRING_PARAM_FLAG (continue_from, "", "Existing model to extend")
 
 STRING_PARAM_FLAG (model_output, "lstmtrain", "Basename for output models")
 
 STRING_PARAM_FLAG (script_dir, "", "Required to set unicharset properties or" " use unicharset compression.")
 
 STRING_PARAM_FLAG (train_listfile, "", "File listing training files in lstmf training format.")
 
 STRING_PARAM_FLAG (eval_listfile, "", "File listing eval files in lstmf training format.")
 
 BOOL_PARAM_FLAG (stop_training, false, "Just convert the training model to a runtime model.")
 
 INT_PARAM_FLAG (append_index, -1, "Index in continue_from Network at which to" " attach the new network defined by net_spec")
 
 BOOL_PARAM_FLAG (debug_network, false, "Get info on distribution of weight values")
 
 INT_PARAM_FLAG (max_iterations, 0, "If set, exit after this many iterations")
 
 DECLARE_STRING_PARAM_FLAG (U)
 
int main (int argc, char **argv)
 

Variables

const int kNumPagesPerBatch = 100
 

Function Documentation

◆ BOOL_PARAM_FLAG() [1/2]

BOOL_PARAM_FLAG ( stop_training  ,
false  ,
"Just convert the training model to a runtime model."   
)

◆ BOOL_PARAM_FLAG() [2/2]

BOOL_PARAM_FLAG ( debug_network  ,
false  ,
"Get info on distribution of weight values"   
)

◆ DECLARE_STRING_PARAM_FLAG()

DECLARE_STRING_PARAM_FLAG ( U )

◆ DOUBLE_PARAM_FLAG() [1/4]

DOUBLE_PARAM_FLAG ( target_error_rate  ,
0.01  ,
"Final error rate in percent."   
)

◆ DOUBLE_PARAM_FLAG() [2/4]

DOUBLE_PARAM_FLAG ( weight_range  ,
0.1  ,
"Range of initial random weights."   
)

◆ DOUBLE_PARAM_FLAG() [3/4]

DOUBLE_PARAM_FLAG ( learning_rate  ,
1.0e-4  ,
"Weight factor for new deltas."   
)

◆ DOUBLE_PARAM_FLAG() [4/4]

DOUBLE_PARAM_FLAG ( momentum  ,
0.9  ,
"Decay factor for repeating deltas."   
)

◆ INT_PARAM_FLAG() [1/7]

INT_PARAM_FLAG ( debug_interval  ,
0  ,
"How often to display the alignment."   
)

◆ INT_PARAM_FLAG() [2/7]

INT_PARAM_FLAG ( train_mode  ,
80  ,
"Controls gross training behavior."   
)

◆ INT_PARAM_FLAG() [3/7]

INT_PARAM_FLAG ( net_mode  ,
192  ,
"Controls network behavior."   
)

◆ INT_PARAM_FLAG() [4/7]

INT_PARAM_FLAG ( perfect_sample_delay  ,
4  ,
"How many imperfect samples between perfect ones."   
)

◆ INT_PARAM_FLAG() [5/7]

INT_PARAM_FLAG ( max_image_MB  ,
6000  ,
"Max memory to use for images."   
)

◆ INT_PARAM_FLAG() [6/7]

INT_PARAM_FLAG ( append_index  ,
-1  ,
"Index in continue_from Network at which to" " attach the new network defined by net_spec"   
)

◆ INT_PARAM_FLAG() [7/7]

INT_PARAM_FLAG ( max_iterations  ,
0  ,
"If set, exit after this many iterations"   
)

◆ main()

int main ( int  argc,
char **  argv 
)

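Note: the description below is inherited from an older feature-based training tool and does not describe what lstmtraining does. As the source listing shows, lstmtraining reads the .lstmf training files named in --train_listfile. It then does one of three things: restores training from an existing <model_output>_checkpoint (or its .bak), continues from --continue_from, or builds a new network from --net_spec and the -U unicharset. It trains in batches of kNumPagesPerBatch iterations and calls the trainer's checkpoint maintenance after each batch, passing an evaluation callback when --eval_listfile is given. It stops once the best error rate is no longer above --target_error_rate, or when a non-zero --max_iterations is reached. With --stop_training, it instead loads --continue_from and writes a recognition model to --model_output; with --debug_network, it prints information about the loaded network. The inherited description follows.

This program reads in a text file consisting of feature samples from a training page in the following format: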

   FontName UTF8-char-str xmin ymin xmax ymax page-number
    NumberOfFeatureTypes(N)
      FeatureTypeName1 NumberOfFeatures(M)
         Feature1
         ...
         FeatureM
      FeatureTypeName2 NumberOfFeatures(M)
         Feature1
         ...
         FeatureM
      ...
      FeatureTypeNameN NumberOfFeatures(M)
         Feature1
         ...
         FeatureM
   FontName CharName ...

The result of this program is a binary inttemp file used by the OCR engine.

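Parameters
argc: number of command line arguments
argv: array of command line arguments
Returns
0 on success; 1 when a required flag is missing, a model, file list, or training/eval data cannot be loaded, or the network cannot be created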
Note
Exceptions: none
History: Fri Aug 18 08:56:17 1989, DSJ, Created.
History: Mon May 18 1998, Christy Russson, Revision started.

Definition at line 66 of file lstmtraining.cpp.

66  {
67  ParseArguments(&argc, &argv);
68  // Purify the model name in case it is based on the network string.
69  if (FLAGS_model_output.empty()) {
70  tprintf("Must provide a --model_output!\n");
71  return 1;
72  }
73  STRING model_output = FLAGS_model_output.c_str();
74  for (int i = 0; i < model_output.length(); ++i) {
75  if (model_output[i] == '[' || model_output[i] == ']')
76  model_output[i] = '-';
77  if (model_output[i] == '(' || model_output[i] == ')')
78  model_output[i] = '_';
79  }
80  // Setup the trainer.
81  STRING checkpoint_file = FLAGS_model_output.c_str();
82  checkpoint_file += "_checkpoint";
83  STRING checkpoint_bak = checkpoint_file + ".bak";
84  tesseract::LSTMTrainer trainer(
85  nullptr, nullptr, nullptr, nullptr, FLAGS_model_output.c_str(),
86  checkpoint_file.c_str(), FLAGS_debug_interval,
87  static_cast<inT64>(FLAGS_max_image_MB) * 1048576);
88 
89  // Reading something from an existing model doesn't require many flags,
90  // so do it now and exit.
91  if (FLAGS_stop_training || FLAGS_debug_network) {
92  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) {
93  tprintf("Failed to read continue from: %s\n",
94  FLAGS_continue_from.c_str());
95  return 1;
96  }
97  if (FLAGS_debug_network) {
98  trainer.DebugNetwork();
99  } else {
100  if (FLAGS_train_mode & tesseract::TF_INT_MODE)
101  trainer.ConvertToInt();
102  GenericVector<char> recognizer_data;
103  trainer.SaveRecognitionDump(&recognizer_data);
104  if (!tesseract::SaveDataToFile(recognizer_data,
105  FLAGS_model_output.c_str())) {
106  tprintf("Failed to write recognition model : %s\n",
107  FLAGS_model_output.c_str());
108  }
109  }
110  return 0;
111  }
112 
113  // Get the list of files to process.
114  if (FLAGS_train_listfile.empty()) {
115  tprintf("Must supply a list of training filenames! --train_listfile\n");
116  return 1;
117  }
118  GenericVector<STRING> filenames;
119  if (!tesseract::LoadFileLinesToStrings(FLAGS_train_listfile.c_str(),
120  &filenames)) {
121  tprintf("Failed to load list of training filenames from %s\n",
122  FLAGS_train_listfile.c_str());
123  return 1;
124  }
125 
126  UNICHARSET unicharset;
127  // Checkpoints always take priority if they are available.
128  if (trainer.TryLoadingCheckpoint(checkpoint_file.string()) ||
129  trainer.TryLoadingCheckpoint(checkpoint_bak.string())) {
130  tprintf("Successfully restored trainer from %s\n",
131  checkpoint_file.string());
132  } else {
133  if (!FLAGS_continue_from.empty()) {
134  // Load a past model file to improve upon.
135  if (!trainer.TryLoadingCheckpoint(FLAGS_continue_from.c_str())) {
136  tprintf("Failed to continue from: %s\n", FLAGS_continue_from.c_str());
137  return 1;
138  }
139  tprintf("Continuing from %s\n", FLAGS_continue_from.c_str());
140  trainer.InitIterations();
141  }
142  if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
143  // We need a unicharset to start from scratch or append.
144  string unicharset_str;
145  // Character coding to be used by the classifier.
146  if (!unicharset.load_from_file(FLAGS_U.c_str())) {
147  tprintf("Error: must provide a -U unicharset!\n");
148  return 1;
149  }
150  tesseract::SetupBasicProperties(true, &unicharset);
151  if (FLAGS_append_index >= 0) {
152  tprintf("Appending a new network to an old one!!");
153  if (FLAGS_continue_from.empty()) {
154  tprintf("Must set --continue_from for appending!\n");
155  return 1;
156  }
157  }
158  // We are initializing from scratch.
159  trainer.InitCharSet(unicharset, FLAGS_script_dir.c_str(),
160  FLAGS_train_mode);
161  if (!trainer.InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
162  FLAGS_net_mode, FLAGS_weight_range,
163  FLAGS_learning_rate, FLAGS_momentum)) {
164  tprintf("Failed to create network from spec: %s\n",
165  FLAGS_net_spec.c_str());
166  return 1;
167  }
168  trainer.set_perfect_delay(FLAGS_perfect_sample_delay);
169  }
170  }
171  if (!trainer.LoadAllTrainingData(filenames)) {
172  tprintf("Load of images failed!!\n");
173  return 1;
174  }
175 
176  tesseract::LSTMTester tester(static_cast<inT64>(FLAGS_max_image_MB) *
177  1048576);
178  tesseract::TestCallback tester_callback = nullptr;
179  if (!FLAGS_eval_listfile.empty()) {
180  if (!tester.LoadAllEvalData(FLAGS_eval_listfile.c_str())) {
181  tprintf("Failed to load eval data from: %s\n",
182  FLAGS_eval_listfile.c_str());
183  return 1;
184  }
185  tester_callback =
187  }
188  do {
189  // Train a few.
190  int iteration = trainer.training_iteration();
191  for (int target_iteration = iteration + kNumPagesPerBatch;
192  iteration < target_iteration;
193  iteration = trainer.training_iteration()) {
194  trainer.TrainOnLine(&trainer, false);
195  }
196  STRING log_str;
197  trainer.MaintainCheckpoints(tester_callback, &log_str);
198  tprintf("%s\n", log_str.string());
199  } while (trainer.best_error_rate() > FLAGS_target_error_rate &&
200  (trainer.training_iteration() < FLAGS_max_iterations ||
201  FLAGS_max_iterations == 0));
202  delete tester_callback;
203  tprintf("Finished! Error rate = %g\n", trainer.best_error_rate());
204  return 0;
205 } /* main */
const int kNumPagesPerBatch
void SetupBasicProperties(bool report_errors, bool decompose, UNICHARSET *unicharset)
int64_t inT64
Definition: host.h:40
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
Definition: tesscallback.h:116
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
inT32 length() const
Definition: strngs.cpp:193
void ParseArguments(int *argc, char ***argv)
Definition: strngs.h:45
bool LoadFileLinesToStrings(const STRING &filename, GenericVector< STRING > *lines)
bool load_from_file(const char *const filename, bool skip_fragments)
Definition: unicharset.h:348
const char * c_str() const
Definition: strngs.cpp:209
STRING RunEvalAsync(int iteration, const double *training_errors, const GenericVector< char > &model_data, int training_stage)
Definition: lstmtester.cpp:52

◆ STRING_PARAM_FLAG() [1/6]

STRING_PARAM_FLAG ( net_spec  ,
""  ,
"Network specification"   
)

◆ STRING_PARAM_FLAG() [2/6]

STRING_PARAM_FLAG ( continue_from  ,
""  ,
"Existing model to extend"   
)

◆ STRING_PARAM_FLAG() [3/6]

STRING_PARAM_FLAG ( model_output  ,
"lstmtrain"  ,
"Basename for output models"   
)

◆ STRING_PARAM_FLAG() [4/6]

STRING_PARAM_FLAG ( script_dir  ,
""  ,
"Required to set unicharset properties or" " use unicharset compression."   
)

◆ STRING_PARAM_FLAG() [5/6]

STRING_PARAM_FLAG ( train_listfile  ,
""  ,
"File listing training files in lstmf training format."   
)

◆ STRING_PARAM_FLAG() [6/6]

STRING_PARAM_FLAG ( eval_listfile  ,
""  ,
"File listing eval files in lstmf training format."   
)

Variable Documentation

◆ kNumPagesPerBatch

const int kNumPagesPerBatch = 100

Definition at line 60 of file lstmtraining.cpp.