tesseract  4.00.00dev
network.h
Go to the documentation of this file.
1 // File: network.h
3 // Description: Base class for neural network implementations.
4 // Author: Ray Smith
5 // Created: Wed May 01 16:38:06 PST 2013
6 //
7 // (C) Copyright 2013, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_NETWORK_H_
20 #define TESSERACT_LSTM_NETWORK_H_
21 
22 #include <stdio.h>
23 #include <cmath>
24 
25 #include "genericvector.h"
26 #include "helpers.h"
27 #include "matrix.h"
28 #include "networkio.h"
29 #include "serialis.h"
30 #include "static_shape.h"
31 #include "tprintf.h"
32 
33 struct Pix;
34 class ScrollView;
35 class TBOX;
36 
37 namespace tesseract {
38 
39 class ImageData;
40 class NetworkScratch;
41 
// Enum to store the run-time type of a Network. Keep in sync with kTypeNames.
enum NetworkType {
  NT_NONE,         // The naked base class.
  NT_INPUT,        // Inputs from an image.
  // Plumbing networks combine other networks or rearrange the inputs.
  NT_CONVOLVE,     // Duplicates inputs in a sliding window neighborhood.
  NT_MAXPOOL,      // Chooses the max result from a rectangle.
  NT_PARALLEL,     // Runs networks in parallel.
  NT_REPLICATED,   // Runs identical networks in parallel.
  NT_PAR_RL_LSTM,  // Runs LTR and RTL LSTMs in parallel.
  NT_PAR_UD_LSTM,  // Runs Up and Down LSTMs in parallel.
  NT_PAR_2D_LSTM,  // Runs 4 LSTMs in parallel.
  NT_SERIES,       // Executes a sequence of layers.
  NT_RECONFIG,     // Scales the time/y size but makes the output deeper.
  NT_XREVERSED,    // Reverses the x direction of the inputs/outputs.
  NT_YREVERSED,    // Reverses the y-direction of the inputs/outputs.
  NT_XYTRANSPOSE,  // Transposes x and y (for just a single op).
  // Functional networks actually calculate stuff.
  NT_LSTM,          // Long-Short-Term-Memory block.
  NT_LSTM_SUMMARY,  // LSTM that only keeps its last output.
  NT_LOGISTIC,      // Fully connected logistic nonlinearity.
  NT_POSCLIP,       // Fully connected rect lin version of logistic.
  NT_SYMCLIP,       // Fully connected rect lin version of tanh.
  NT_TANH,          // Fully connected with tanh nonlinearity.
  NT_RELU,          // Fully connected with rectifier nonlinearity.
  NT_LINEAR,        // Fully connected with no nonlinearity.
  NT_SOFTMAX,       // Softmax uses exponential normalization, with CTC.
  NT_SOFTMAX_NO_CTC,  // Softmax uses exponential normalization, no CTC.
  // The SOFTMAX LSTMs both have an extra softmax layer on top, but inside, with
  // the outputs fed back to the input of the LSTM at the next timestep.
  // The ENCODED version binary encodes the softmax outputs, providing log2 of
  // the number of outputs as additional inputs, and the other version just
  // provides all the softmax outputs as additional inputs.
  NT_LSTM_SOFTMAX,          // 1-d LSTM with built-in fully connected softmax.
  NT_LSTM_SOFTMAX_ENCODED,  // 1-d LSTM with built-in binary encoded softmax.
  // A TensorFlow graph encapsulated as a Tesseract network.
  NT_TENSORFLOW,

  NT_COUNT  // Array size.
};
82 
// Enum of Network behavior flags. Can in theory be set for each individual
// network element.
enum NetworkFlags {
  // Network forward/backprop behavior.
  NF_LAYER_SPECIFIC_LR = 64,  // Separate learning rate for each layer.
  NF_ADA_GRAD = 128,          // Weight-specific learning rate.
};
90 
// State of training and desired state used in SetEnableTraining.
enum TrainingState {
  // Valid states of training_.
  TS_DISABLED,      // Disabled permanently.
  TS_ENABLED,       // Enabled for backprop and to write a training dump.
                    // Re-enable from ANY disabled state.
  TS_TEMP_DISABLE,  // Temporarily disabled to write a recognition dump.
  // Valid only for SetEnableTraining.
  TS_RE_ENABLE,  // Re-Enable from TS_TEMP_DISABLE, but not TS_DISABLED.
};
101 
102 // Base class for network types. Not quite an abstract base class, but almost.
103 // Most of the time no isolated Network exists, except prior to
104 // deserialization.
105 class Network {
106  public:
107  Network();
108  Network(NetworkType type, const STRING& name, int ni, int no);
109  virtual ~Network();
110 
111  // Accessors.
112  NetworkType type() const {
113  return type_;
114  }
115  bool IsTraining() const { return training_ == TS_ENABLED; }
116  bool needs_to_backprop() const {
117  return needs_to_backprop_;
118  }
119  int num_weights() const { return num_weights_; }
120  int NumInputs() const {
121  return ni_;
122  }
123  int NumOutputs() const {
124  return no_;
125  }
126  // Returns the required shape input to the network.
127  virtual StaticShape InputShape() const {
128  StaticShape result;
129  return result;
130  }
131  // Returns the shape output from the network given an input shape (which may
132  // be partially unknown ie zero).
133  virtual StaticShape OutputShape(const StaticShape& input_shape) const {
134  StaticShape result(input_shape);
135  result.set_depth(no_);
136  return result;
137  }
138  const STRING& name() const {
139  return name_;
140  }
141  virtual STRING spec() const {
142  return "?";
143  }
144  bool TestFlag(NetworkFlags flag) const {
145  return (network_flags_ & flag) != 0;
146  }
147 
148  // Initialization and administrative functions that are mostly provided
149  // by Plumbing.
150  // Returns true if the given type is derived from Plumbing, and thus contains
151  // multiple sub-networks that can have their own learning rate.
152  virtual bool IsPlumbingType() const { return false; }
153 
154  // Suspends/Enables/Permanently disables training by setting the training_
155  // flag. Serialize and DeSerialize only operate on the run-time data if state
156  // is TS_DISABLED or TS_TEMP_DISABLE. Specifying TS_TEMP_DISABLE will
157  // temporarily disable layers in state TS_ENABLED, allowing a trainer to
158  // serialize as if it were a recognizer.
159  // TS_RE_ENABLE will re-enable layers that were previously in any disabled
160  // state. If in TS_TEMP_DISABLE then the flag is just changed, but if in
161  // TS_DISABLED, the deltas in the weight matrices are reinitialized so that a
162  // recognizer can be converted back to a trainer.
163  virtual void SetEnableTraining(TrainingState state);
164 
165  // Sets flags that control the action of the network. See NetworkFlags enum
166  // for bit values.
167  virtual void SetNetworkFlags(uinT32 flags);
168 
169  // Sets up the network for training. Initializes weights using weights of
170  // scale `range` picked according to the random number generator `randomizer`.
171  // Note that randomizer is a borrowed pointer that should outlive the network
172  // and should not be deleted by any of the networks.
173  // Returns the number of weights initialized.
174  virtual int InitWeights(float range, TRand* randomizer);
175 
176  // Converts a float network to an int network.
177  virtual void ConvertToInt() {}
178 
179  // Provides a pointer to a TRand for any networks that care to use it.
180  // Note that randomizer is a borrowed pointer that should outlive the network
181  // and should not be deleted by any of the networks.
182  virtual void SetRandomizer(TRand* randomizer);
183 
184  // Sets needs_to_backprop_ to needs_backprop and returns true if
185  // needs_backprop || any weights in this network so the next layer forward
186  // can be told to produce backprop for this layer if needed.
187  virtual bool SetupNeedsBackprop(bool needs_backprop);
188 
189  // Returns the most recent reduction factor that the network applied to the
190  // time sequence. Assumes that any 2-d is already eliminated. Used for
191  // scaling bounding boxes of truth data and calculating result bounding boxes.
192  // WARNING: if GlobalMinimax is used to vary the scale, this will return
193  // the last used scale factor. Call it before any forward, and it will return
194  // the minimum scale factor of the paths through the GlobalMinimax.
195  virtual int XScaleFactor() const {
196  return 1;
197  }
198 
199  // Provides the (minimum) x scale factor to the network (of interest only to
200  // input units) so they can determine how to scale bounding boxes.
201  virtual void CacheXScaleFactor(int factor) {}
202 
203  // Provides debug output on the weights.
204  virtual void DebugWeights() {
205  tprintf("Must override Network::DebugWeights for type %d\n", type_);
206  }
207 
208  // Writes to the given file. Returns false in case of error.
209  // Should be overridden by subclasses, but called by their Serialize.
210  virtual bool Serialize(TFile* fp) const;
211  // Reads from the given file. Returns false in case of error.
212  // Should be overridden by subclasses, but NOT called by their DeSerialize.
213  virtual bool DeSerialize(TFile* fp);
214 
215  // Updates the weights using the given learning rate and momentum.
216  // num_samples is the quotient to be used in the adagrad computation iff
217  // use_ada_grad_ is true.
218  virtual void Update(float learning_rate, float momentum, int num_samples) {}
219  // Sums the products of weight updates in *this and other, splitting into
220  // positive (same direction) in *same and negative (different direction) in
221  // *changed.
222  virtual void CountAlternators(const Network& other, double* same,
223  double* changed) const {}
224 
225  // Reads from the given file. Returns NULL in case of error.
226  // Determines the type of the serialized class and calls its DeSerialize
227  // on a new object of the appropriate type, which is returned.
228  static Network* CreateFromFile(TFile* fp);
229 
230  // Runs forward propagation of activations on the input line.
231  // Note that input and output are both 2-d arrays.
232  // The 1st index is the time element. In a 1-d network, it might be the pixel
233  // position on the textline. In a 2-d network, the linearization is defined
234  // by the stride_map. (See networkio.h).
235  // The 2nd index of input is the network inputs/outputs, and the dimension
236  // of the input must match NumInputs() of this network.
237  // The output array will be resized as needed so that its 1st dimension is
238  // always equal to the number of output values, and its second dimension is
239  // always NumOutputs(). Note that all this detail is encapsulated away inside
240  // NetworkIO, as are the internals of the scratch memory space used by the
241  // network. See networkscratch.h for that.
242  // If input_transpose is not NULL, then it contains the transpose of input,
243  // and the caller guarantees that it will still be valid on the next call to
244  // backward. The callee is therefore at liberty to save the pointer and
245  // reference it on a call to backward. This is a bit ugly, but it makes it
246  // possible for a replicating parallel to calculate the input transpose once
247  // instead of all the replicated networks having to do it.
248  virtual void Forward(bool debug, const NetworkIO& input,
249  const TransposedArray* input_transpose,
250  NetworkScratch* scratch, NetworkIO* output) {
251  tprintf("Must override Network::Forward for type %d\n", type_);
252  }
253 
254  // Runs backward propagation of errors on fwdX_deltas.
255  // Note that fwd_deltas and back_deltas are both 2-d arrays as with Forward.
256  // Returns false if back_deltas was not set, due to there being no point in
257  // propagating further backwards. Thus most complete networks will always
258  // return false from Backward!
259  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
260  NetworkScratch* scratch,
261  NetworkIO* back_deltas) {
262  tprintf("Must override Network::Backward for type %d\n", type_);
263  return false;
264  }
265 
266  // === Debug image display methods. ===
267  // Displays the image of the matrix to the forward window.
268  void DisplayForward(const NetworkIO& matrix);
269  // Displays the image of the matrix to the backward window.
270  void DisplayBackward(const NetworkIO& matrix);
271 
272  // Creates the window if needed, otherwise clears it.
273  static void ClearWindow(bool tess_coords, const char* window_name,
274  int width, int height, ScrollView** window);
275 
276  // Displays the pix in the given window. and returns the height of the pix.
277  // The pix is pixDestroyed.
278  static int DisplayImage(Pix* pix, ScrollView* window);
279 
280  protected:
281  // Returns a random number in [-range, range].
282  double Random(double range);
283 
284  protected:
285  NetworkType type_; // Type of the derived network class.
286  TrainingState training_; // Are we currently training?
287  bool needs_to_backprop_; // This network needs to output back_deltas.
288  inT32 network_flags_; // Behavior control flags in NetworkFlags.
289  inT32 ni_; // Number of input values.
290  inT32 no_; // Number of output values.
291  inT32 num_weights_; // Number of weights in this and sub-network.
292  STRING name_; // A unique name for this layer.
293 
294  // NOT-serialized debug data.
295  ScrollView* forward_win_; // Recognition debug display window.
296  ScrollView* backward_win_; // Training debug display window.
297  TRand* randomizer_; // Random number generator.
298 
299  // Static serialized name/type_ mapping. Keep in sync with NetworkType.
300  static char const* const kTypeNames[NT_COUNT];
301 };
302 
303 
304 } // namespace tesseract.
305 
306 #endif // TESSERACT_LSTM_NETWORK_H_
virtual int InitWeights(float range, TRand *randomizer)
Definition: network.cpp:132
virtual int XScaleFactor() const
Definition: network.h:195
virtual bool IsPlumbingType() const
Definition: network.h:152
bool needs_to_backprop_
Definition: network.h:287
int num_weights() const
Definition: network.h:119
int32_t inT32
Definition: host.h:38
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
virtual void Update(float learning_rate, float momentum, int num_samples)
Definition: network.h:218
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:285
ScrollView * backward_win_
Definition: network.h:296
NetworkType type() const
Definition: network.h:112
#define tprintf(...)
Definition: tprintf.h:31
static Network * CreateFromFile(TFile *fp)
Definition: network.cpp:203
bool needs_to_backprop() const
Definition: network.h:116
bool IsTraining() const
Definition: network.h:115
TrainingState
Definition: network.h:92
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: network.h:248
TRand * randomizer_
Definition: network.h:297
virtual void CacheXScaleFactor(int factor)
Definition: network.h:201
virtual void SetNetworkFlags(uinT32 flags)
Definition: network.cpp:126
inT32 network_flags_
Definition: network.h:288
virtual bool DeSerialize(TFile *fp)
Definition: network.cpp:172
uint32_t uinT32
Definition: host.h:39
virtual StaticShape InputShape() const
Definition: network.h:127
TrainingState training_
Definition: network.h:286
virtual void ConvertToInt()
Definition: network.h:177
virtual void DebugWeights()
Definition: network.h:204
STRING
Definition: strngs.h:45
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
double Random(double range)
Definition: network.cpp:278
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
NetworkFlags
Definition: network.h:85
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: network.h:259
static char const *const kTypeNames[NT_COUNT]
Definition: network.h:300
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: network.h:133
ScrollView * forward_win_
Definition: network.h:295
TBOX
Definition: rect.h:30
NetworkType
Definition: network.h:43
NetworkType type_
Definition: network.h:285
virtual ~Network()
Definition: network.cpp:100
int NumInputs() const
Definition: network.h:120
const STRING & name() const
Definition: network.h:138
virtual STRING spec() const
Definition: network.h:141
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:332
int NumOutputs() const
Definition: network.h:123
virtual void CountAlternators(const Network &other, double *same, double *changed) const
Definition: network.h:222
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:296
inT32 num_weights_
Definition: network.h:291
virtual void SetEnableTraining(TrainingState state)
Definition: network.cpp:112
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: network.cpp:147
void set_depth(int value)
Definition: static_shape.h:47