tesseract  4.00.00dev
network.cpp
Go to the documentation of this file.
1 // File: network.cpp
3 // Description: Base class for neural network implementations.
4 // Author: Ray Smith
5 // Created: Wed May 01 17:25:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 // Include automatically generated configuration file if running autoconf.
20 #ifdef HAVE_CONFIG_H
21 #include "config_auto.h"
22 #endif
23 
24 #include "network.h"
25 
26 #include <stdlib.h>
27 
28 // This base class needs to know about all its sub-classes because of the
29 // factory deserializing method: CreateFromFile.
30 #include "allheaders.h"
31 #include "convolve.h"
32 #include "fullyconnected.h"
33 #include "input.h"
34 #include "lstm.h"
35 #include "maxpool.h"
36 #include "parallel.h"
37 #include "reconfig.h"
38 #include "reversed.h"
39 #include "scrollview.h"
40 #include "series.h"
41 #include "statistc.h"
42 #ifdef INCLUDE_TENSORFLOW
43 #include "tfnetwork.h"
44 #endif
45 #include "tprintf.h"
46 
47 namespace tesseract {
48 
// Bounds on the size of a newly created debug window, in pixels.
// ClearWindow scales content whose smaller side is below kMinWinSize up to
// that size, and caps each framed window dimension at kMaxWinSize.
const int kMinWinSize = 500;
const int kMaxWinSize = 2000;
// Extra room added to the content size so the content fits inside the
// window frame (horizontal and vertical respectively).
const int kXWinFrameSize = 30;
const int kYWinFrameSize = 80;
55 
56 // String names corresponding to the NetworkType enum. Keep in sync.
57 // Names used in Serialization to allow re-ordering/addition/deletion of
58 // layer types in NetworkType without invalidating existing network files.
59 char const* const Network::kTypeNames[NT_COUNT] = {
60  "Invalid", "Input",
61  "Convolve", "Maxpool",
62  "Parallel", "Replicated",
63  "ParBidiLSTM", "DepParUDLSTM",
64  "Par2dLSTM", "Series",
65  "Reconfig", "RTLReversed",
66  "TTBReversed", "XYTranspose",
67  "LSTM", "SummLSTM",
68  "Logistic", "LinLogistic",
69  "LinTanh", "Tanh",
70  "Relu", "Linear",
71  "Softmax", "SoftmaxNoCTC",
72  "LSTMSoftmax", "LSTMBinarySoftmax",
73  "TensorFlow",
74 };
75 
77  : type_(NT_NONE),
78  training_(TS_ENABLED),
79  needs_to_backprop_(true),
80  network_flags_(0),
81  ni_(0),
82  no_(0),
83  num_weights_(0),
84  forward_win_(NULL),
85  backward_win_(NULL),
86  randomizer_(NULL) {}
87 Network::Network(NetworkType type, const STRING& name, int ni, int no)
88  : type_(type),
90  needs_to_backprop_(true),
91  network_flags_(0),
92  ni_(ni),
93  no_(no),
94  num_weights_(0),
95  name_(name),
96  forward_win_(NULL),
97  backward_win_(NULL),
98  randomizer_(NULL) {}
99 
101 }
102 
103 // Suspends/Enables/Permanently disables training by setting the training_
104 // flag. Serialize and DeSerialize only operate on the run-time data if state
105 // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
106 // temporarily disable layers in state TS_ENABLED, allowing a trainer to
107 // serialize as if it were a recognizer.
108 // TS_RE_ENABLE will re-enable layers that were previously in any disabled
109 // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
110 // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
111 // recognizer can be converted back to a trainer.
113  if (state == TS_RE_ENABLE) {
114  // Enable only from temp disabled.
116  } else if (state == TS_TEMP_DISABLE) {
117  // Temp disable only from enabled.
118  if (training_ == TS_ENABLED) training_ = state;
119  } else {
120  training_ = state;
121  }
122 }
123 
124 // Sets flags that control the action of the network. See NetworkFlags enum
125 // for bit values.
127  network_flags_ = flags;
128 }
129 
130 // Sets up the network for training. Initializes weights using weights of
131 // scale `range` picked according to the random number generator `randomizer`.
132 int Network::InitWeights(float range, TRand* randomizer) {
133  randomizer_ = randomizer;
134  return 0;
135 }
136 
137 // Provides a pointer to a TRand for any networks that care to use it.
138 // Note that randomizer is a borrowed pointer that should outlive the network
139 // and should not be deleted by any of the networks.
140 void Network::SetRandomizer(TRand* randomizer) {
141  randomizer_ = randomizer;
142 }
143 
144 // Sets needs_to_backprop_ to needs_backprop and returns true if
145 // needs_backprop || any weights in this network so the next layer forward
146 // can be told to produce backprop for this layer if needed.
147 bool Network::SetupNeedsBackprop(bool needs_backprop) {
148  needs_to_backprop_ = needs_backprop;
149  return needs_backprop || num_weights_ > 0;
150 }
151 
152 // Writes to the given file. Returns false in case of error.
153 bool Network::Serialize(TFile* fp) const {
154  inT8 data = NT_NONE;
155  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
156  STRING type_name = kTypeNames[type_];
157  if (!type_name.Serialize(fp)) return false;
158  data = training_;
159  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
160  data = needs_to_backprop_;
161  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
162  if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
163  if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false;
164  if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false;
165  if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
166  if (!name_.Serialize(fp)) return false;
167  return true;
168 }
169 
170 // Reads from the given file. Returns false in case of error.
171 // Should be overridden by subclasses, but NOT called by their DeSerialize.
173  inT8 data = 0;
174  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
175  if (data == NT_NONE) {
176  STRING type_name;
177  if (!type_name.DeSerialize(fp)) return false;
178  for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
179  }
180  if (data == NT_COUNT) {
181  tprintf("Invalid network layer type:%s\n", type_name.string());
182  return false;
183  }
184  }
185  type_ = static_cast<NetworkType>(data);
186  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
188  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
189  needs_to_backprop_ = data != 0;
190  if (fp->FReadEndian(&network_flags_, sizeof(network_flags_), 1) != 1)
191  return false;
192  if (fp->FReadEndian(&ni_, sizeof(ni_), 1) != 1) return false;
193  if (fp->FReadEndian(&no_, sizeof(no_), 1) != 1) return false;
194  if (fp->FReadEndian(&num_weights_, sizeof(num_weights_), 1) != 1)
195  return false;
196  if (!name_.DeSerialize(fp)) return false;
197  return true;
198 }
199 
200 // Reads from the given file. Returns NULL in case of error.
201 // Determines the type of the serialized class and calls its DeSerialize
202 // on a new object of the appropriate type, which is returned.
204  Network stub;
205  if (!stub.DeSerialize(fp)) return NULL;
206  Network* network = NULL;
207  switch (stub.type_) {
208  case NT_CONVOLVE:
209  network = new Convolve(stub.name_, stub.ni_, 0, 0);
210  break;
211  case NT_INPUT:
212  network = new Input(stub.name_, stub.ni_, stub.no_);
213  break;
214  case NT_LSTM:
215  case NT_LSTM_SOFTMAX:
217  case NT_LSTM_SUMMARY:
218  network =
219  new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
220  break;
221  case NT_MAXPOOL:
222  network = new Maxpool(stub.name_, stub.ni_, 0, 0);
223  break;
224  // All variants of Parallel.
225  case NT_PARALLEL:
226  case NT_REPLICATED:
227  case NT_PAR_RL_LSTM:
228  case NT_PAR_UD_LSTM:
229  case NT_PAR_2D_LSTM:
230  network = new Parallel(stub.name_, stub.type_);
231  break;
232  case NT_RECONFIG:
233  network = new Reconfig(stub.name_, stub.ni_, 0, 0);
234  break;
235  // All variants of reversed.
236  case NT_XREVERSED:
237  case NT_YREVERSED:
238  case NT_XYTRANSPOSE:
239  network = new Reversed(stub.name_, stub.type_);
240  break;
241  case NT_SERIES:
242  network = new Series(stub.name_);
243  break;
244  case NT_TENSORFLOW:
245 #ifdef INCLUDE_TENSORFLOW
246  network = new TFNetwork(stub.name_);
247 #else
248  tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
249  return NULL;
250 #endif
251  break;
252  // All variants of FullyConnected.
253  case NT_SOFTMAX:
254  case NT_SOFTMAX_NO_CTC:
255  case NT_RELU:
256  case NT_TANH:
257  case NT_LINEAR:
258  case NT_LOGISTIC:
259  case NT_POSCLIP:
260  case NT_SYMCLIP:
261  network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
262  break;
263  default:
264  return NULL;
265  }
266  network->training_ = stub.training_;
267  network->needs_to_backprop_ = stub.needs_to_backprop_;
268  network->network_flags_ = stub.network_flags_;
269  network->num_weights_ = stub.num_weights_;
270  if (!network->DeSerialize(fp)) {
271  delete network;
272  return NULL;
273  }
274  return network;
275 }
276 
277 // Returns a random number in [-range, range].
278 double Network::Random(double range) {
279  ASSERT_HOST(randomizer_ != NULL);
280  return randomizer_->SignedRand(range);
281 }
282 
283 // === Debug image display methods. ===
284 // Displays the image of the matrix to the forward window.
285 void Network::DisplayForward(const NetworkIO& matrix) {
286 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
287  Pix* image = matrix.ToPix();
288  ClearWindow(false, name_.string(), pixGetWidth(image),
289  pixGetHeight(image), &forward_win_);
290  DisplayImage(image, forward_win_);
291  forward_win_->Update();
292 #endif // GRAPHICS_DISABLED
293 }
294 
295 // Displays the image of the matrix to the backward window.
296 void Network::DisplayBackward(const NetworkIO& matrix) {
297 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
298  Pix* image = matrix.ToPix();
299  STRING window_name = name_ + "-back";
300  ClearWindow(false, window_name.string(), pixGetWidth(image),
301  pixGetHeight(image), &backward_win_);
302  DisplayImage(image, backward_win_);
304 #endif // GRAPHICS_DISABLED
305 }
306 
307 #ifndef GRAPHICS_DISABLED
308 // Creates the window if needed, otherwise clears it.
309 void Network::ClearWindow(bool tess_coords, const char* window_name,
310  int width, int height, ScrollView** window) {
311  if (*window == NULL) {
312  int min_size = MIN(width, height);
313  if (min_size < kMinWinSize) {
314  if (min_size < 1) min_size = 1;
315  width = width * kMinWinSize / min_size;
316  height = height * kMinWinSize / min_size;
317  }
318  width += kXWinFrameSize;
319  height += kYWinFrameSize;
320  if (width > kMaxWinSize) width = kMaxWinSize;
321  if (height > kMaxWinSize) height = kMaxWinSize;
322  *window = new ScrollView(window_name, 80, 100, width, height, width, height,
323  tess_coords);
324  tprintf("Created window %s of size %d, %d\n", window_name, width, height);
325  } else {
326  (*window)->Clear();
327  }
328 }
329 
330 // Displays the pix in the given window. and returns the height of the pix.
331 // The pix is pixDestroyed.
332 int Network::DisplayImage(Pix* pix, ScrollView* window) {
333  int height = pixGetHeight(pix);
334  window->Image(pix, 0, 0);
335  pixDestroy(&pix);
336  return height;
337 }
338 #endif // GRAPHICS_DISABLED
339 
340 } // namespace tesseract.
virtual int InitWeights(float range, TRand *randomizer)
Definition: network.cpp:132
bool needs_to_backprop_
Definition: network.h:287
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:285
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:163
const int kXWinFrameSize
Definition: network.cpp:53
ScrollView * backward_win_
Definition: network.h:296
NetworkType type() const
Definition: network.h:112
#define tprintf(...)
Definition: tprintf.h:31
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:203
const char * string() const
Definition: strngs.cpp:198
int FReadEndian(void *buffer, int size, int count)
Definition: serialis.cpp:97
TrainingState
Definition: network.h:92
TRand * randomizer_
Definition: network.h:297
#define ASSERT_HOST(x)
Definition: errcode.h:84
virtual void SetNetworkFlags(uinT32 flags)
Definition: network.cpp:126
inT32 network_flags_
Definition: network.h:288
const int kYWinFrameSize
Definition: network.cpp:54
virtual bool DeSerialize(TFile *fp)
Definition: network.cpp:172
uint32_t uinT32
Definition: host.h:39
TrainingState training_
Definition: network.h:286
Definition: strngs.h:45
static void Update()
Definition: scrollview.cpp:715
bool Serialize(FILE *fp) const
Definition: strngs.cpp:148
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
double Random(double range)
Definition: network.cpp:278
static char const *const kTypeNames[NT_COUNT]
Definition: network.h:300
int FWrite(const void *buffer, int size, int count)
Definition: serialis.cpp:148
const int kMinWinSize
Definition: network.cpp:50
ScrollView * forward_win_
Definition: network.h:295
NetworkType
Definition: network.h:43
int8_t inT8
Definition: host.h:34
NetworkType type_
Definition: network.h:285
#define MIN(x, y)
Definition: ndminx.h:28
virtual ~Network()
Definition: network.cpp:100
const STRING & name() const
Definition: network.h:138
void Image(struct Pix *image, int x_pos, int y_pos)
Definition: scrollview.cpp:773
double SignedRand(double range)
Definition: helpers.h:60
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:332
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
Pix * ToPix() const
Definition: networkio.cpp:286
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:296
inT32 num_weights_
Definition: network.h:291
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: network.cpp:147
const int kMaxWinSize
Definition: network.cpp:51
int FRead(void *buffer, int size, int count)
Definition: serialis.cpp:108