// Command-line flag definitions for the LSTM training driver (main() below).
// NOTE(review): this listing is incomplete — several entries (the
// "How many imperfect samples…", "Required to set unicharset properties…",
// "File listing training/eval files…", "Just convert the training model…" and
// "Get info on distribution of weight values" descriptions) appear without
// their *_PARAM_FLAG(name, default, macro head; confirm against the full file.
19 #ifndef USE_STD_NAMESPACE 20 #include "base/commandlineflags.h" 30 INT_PARAM_FLAG(debug_interval, 0,
"How often to display the alignment.");
32 INT_PARAM_FLAG(train_mode, 80,
"Controls gross training behavior.");
35 "How many imperfect samples between perfect ones.");
// Image memory budget in megabytes; main() converts it to bytes
// (multiplies by 1048576) when constructing the trainer.
40 INT_PARAM_FLAG(max_image_MB, 6000,
"Max memory to use for images.");
44 "Required to set unicharset properties or" 45 " use unicharset compression.");
47 "File listing training files in lstmf training format.");
49 "File listing eval files in lstmf training format.");
51 "Just convert the training model to a runtime model.");
// Default -1; main() treats any value >= 0 as "append a new network to the
// one loaded from --continue_from".
52 INT_PARAM_FLAG(append_index, -1,
"Index in continue_from Network at which to" 53 " attach the new network defined by net_spec");
55 "Get info on distribution of weight values");
// Default 0; main()'s training-loop condition checks FLAGS_max_iterations == 0,
// presumably meaning "no iteration limit" — confirm against the loop head.
56 INT_PARAM_FLAG(max_iterations, 0,
"If set, exit after this many iterations");
// Entry point of the LSTM training driver.
// Visible flow: validate --model_output, derive checkpoint file names, build
// the trainer, optionally just convert/dump a model (--stop_training /
// --debug_network), load the training list, restore from a checkpoint or
// --continue_from or build a fresh network from --net_spec, load eval data,
// then run the training loop.
// NOTE(review): this listing omits many lines (returns, the trainer
// declaration head, the loop head); comments below describe only what is shown.
66 int main(
int argc,
char **argv) {
// An output model name is mandatory.
69 if (FLAGS_model_output.empty()) {
70 tprintf(
"Must provide a --model_output!\n");
// Replace bracket characters in model_output: '[' / ']' -> '-' and
// '(' / ')' -> '_'.
// NOTE(review): `int i` is compared against length(); if length() returns an
// unsigned type this is a signed/unsigned comparison.
// NOTE(review): in the visible lines the sanitized model_output is never used
// afterwards; checkpoint_file below is built from FLAGS_model_output instead.
74 for (
int i = 0; i < model_output.
length(); ++i) {
75 if (model_output[i] ==
'[' || model_output[i] ==
']')
76 model_output[i] =
'-';
77 if (model_output[i] ==
'(' || model_output[i] ==
')')
78 model_output[i] =
'_';
// Checkpoint file is "<model_output>_checkpoint"; the backup adds ".bak".
81 STRING checkpoint_file = FLAGS_model_output.
c_str();
82 checkpoint_file +=
"_checkpoint";
83 STRING checkpoint_bak = checkpoint_file +
".bak";
// Trainer construction arguments (the call head is not shown here).
// The int flag is widened to inT64 before scaling MB -> bytes, because
// 6000 * 1048576 would overflow a 32-bit int.
85 nullptr,
nullptr,
nullptr,
nullptr, FLAGS_model_output.c_str(),
86 checkpoint_file.
c_str(), FLAGS_debug_interval,
87 static_cast<inT64>(FLAGS_max_image_MB) * 1048576);
// Conversion / inspection-only mode: load --continue_from and then either
// debug the network or write a recognition model to --model_output.
91 if (FLAGS_stop_training || FLAGS_debug_network) {
93 tprintf(
"Failed to read continue from: %s\n",
94 FLAGS_continue_from.c_str());
97 if (FLAGS_debug_network) {
105 FLAGS_model_output.c_str())) {
106 tprintf(
"Failed to write recognition model : %s\n",
107 FLAGS_model_output.c_str());
// Training mode requires a list of training files.
114 if (FLAGS_train_listfile.empty()) {
115 tprintf(
"Must supply a list of training filenames! --train_listfile\n");
121 tprintf(
"Failed to load list of training filenames from %s\n",
122 FLAGS_train_listfile.c_str());
// NOTE(review): uses STRING::string() here while every other message uses
// c_str(); both accessors exist, but c_str() would be consistent.
130 tprintf(
"Successfully restored trainer from %s\n",
131 checkpoint_file.
string());
// Otherwise try to continue from a previously trained model.
133 if (!FLAGS_continue_from.empty()) {
136 tprintf(
"Failed to continue from: %s\n", FLAGS_continue_from.c_str());
139 tprintf(
"Continuing from %s\n", FLAGS_continue_from.c_str());
// Build a new network when there is nothing to continue from, or when
// appending a new network (append_index >= 0) onto a loaded one.
// A unicharset (-U) is required in both cases.
142 if (FLAGS_continue_from.empty() || FLAGS_append_index >= 0) {
144 string unicharset_str;
147 tprintf(
"Error: must provide a -U unicharset!\n");
// Appending is only valid together with --continue_from.
151 if (FLAGS_append_index >= 0) {
// NOTE(review): unlike the other messages, this one has no trailing "\n".
152 tprintf(
"Appending a new network to an old one!!");
153 if (FLAGS_continue_from.empty()) {
154 tprintf(
"Must set --continue_from for appending!\n");
159 trainer.
InitCharSet(unicharset, FLAGS_script_dir.c_str(),
// Create the network from --net_spec with the configured mode, weight range,
// learning rate and momentum.
161 if (!trainer.
InitNetwork(FLAGS_net_spec.c_str(), FLAGS_append_index,
162 FLAGS_net_mode, FLAGS_weight_range,
163 FLAGS_learning_rate, FLAGS_momentum)) {
164 tprintf(
"Failed to create network from spec: %s\n",
165 FLAGS_net_spec.c_str());
172 tprintf(
"Load of images failed!!\n");
// Eval data is optional; it is loaded only when --eval_listfile is given.
179 if (!FLAGS_eval_listfile.empty()) {
181 tprintf(
"Failed to load eval data from: %s\n",
182 FLAGS_eval_listfile.c_str());
// Training loop fragments: an inner loop bounded by target_iteration and an
// outer condition that involves FLAGS_max_iterations == 0.
192 iteration < target_iteration;
201 FLAGS_max_iterations == 0));
// NOTE(review): raw delete of the tester callback; an owning smart pointer
// would release it on every exit path — confirm the callback's ownership.
202 delete tester_callback;
// Signatures of functions, constants and flags referenced by the training
// driver, listed without bodies or terminating semicolons.
// NOTE(review): these read as a signature index rather than compilable
// declarations (no ';', no enclosing class scope); confirm their intent
// before relying on them.
double best_error_rate() const
const int kNumPagesPerBatch
void SetupBasicProperties(bool report_errors, bool decompose, UNICHARSET *unicharset)
_ConstTessMemberResultCallback_0_0< false, R, T1 >::base * NewPermanentTessCallback(const T1 *obj, R(T2::*member)() const)
bool SaveDataToFile(const GenericVector< char > &data, const STRING &filename)
void InitCharSet(const UNICHARSET &unicharset, const STRING &script_dir, int train_flags)
DECLARE_STRING_PARAM_FLAG(U)
bool InitNetwork(const STRING &network_spec, int append_index, int net_flags, float weight_range, float learning_rate, float momentum)
const char * string() const
bool TryLoadingCheckpoint(const char *filename)
void ParseArguments(int *argc, char ***argv)
DOUBLE_PARAM_FLAG(target_error_rate, 0.01, "Final error rate in percent.")
bool MaintainCheckpoints(TestCallback tester, STRING *log_msg)
void set_perfect_delay(int delay)
BOOL_PARAM_FLAG(stop_training, false, "Just convert the training model to a runtime model.")
INT_PARAM_FLAG(debug_interval, 0, "How often to display the alignment.")
STRING_PARAM_FLAG(net_spec, "", "Network specification")
bool LoadFileLinesToStrings(const STRING &filename, GenericVector< STRING > *lines)
const ImageData * TrainOnLine(LSTMTrainer *samples_trainer, bool batch)
bool load_from_file(const char *const filename, bool skip_fragments)
bool LoadAllTrainingData(const GenericVector< STRING > &filenames)
bool LoadAllEvalData(const STRING &filenames_file)
int main(int argc, char **argv)
const char * c_str() const
int training_iteration() const
STRING RunEvalAsync(int iteration, const double *training_errors, const GenericVector< char > &model_data, int training_stage)
void SaveRecognitionDump(GenericVector< char > *data) const