tesseract  4.00.00dev
series.h
Go to the documentation of this file.
// File: series.h
// Description: Runs networks in series on the same input.
// Author: Ray Smith
// Created: Thu May 02 08:20:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_SERIES_H_
20 #define TESSERACT_LSTM_SERIES_H_
21 
22 #include "plumbing.h"
23 
24 namespace tesseract {
25 
26 // Runs two or more networks in series (layers) on the same input.
27 class Series : public Plumbing {
28  public:
29  // ni_ and no_ will be set by AddToStack.
30  explicit Series(const STRING& name);
31  virtual ~Series();
32 
33  // Returns the shape output from the network given an input shape (which may
34  // be partially unknown ie zero).
35  virtual StaticShape OutputShape(const StaticShape& input_shape) const;
36 
37  virtual STRING spec() const {
38  STRING spec("[");
39  for (int i = 0; i < stack_.size(); ++i)
40  spec += stack_[i]->spec();
41  spec += "]";
42  return spec;
43  }
44 
45  // Sets up the network for training. Initializes weights using weights of
46  // scale `range` picked according to the random number generator `randomizer`.
47  // Returns the number of weights initialized.
48  virtual int InitWeights(float range, TRand* randomizer);
49 
50  // Sets needs_to_backprop_ to needs_backprop and returns true if
51  // needs_backprop || any weights in this network so the next layer forward
52  // can be told to produce backprop for this layer if needed.
53  virtual bool SetupNeedsBackprop(bool needs_backprop);
54 
55  // Returns an integer reduction factor that the network applies to the
56  // time sequence. Assumes that any 2-d is already eliminated. Used for
57  // scaling bounding boxes of truth data.
58  // WARNING: if GlobalMinimax is used to vary the scale, this will return
59  // the last used scale factor. Call it before any forward, and it will return
60  // the minimum scale factor of the paths through the GlobalMinimax.
61  virtual int XScaleFactor() const;
62 
63  // Provides the (minimum) x scale factor to the network (of interest only to
64  // input units) so they can determine how to scale bounding boxes.
65  virtual void CacheXScaleFactor(int factor);
66 
67  // Runs forward propagation of activations on the input line.
68  // See Network for a detailed discussion of the arguments.
69  virtual void Forward(bool debug, const NetworkIO& input,
70  const TransposedArray* input_transpose,
71  NetworkScratch* scratch, NetworkIO* output);
72 
73  // Runs backward propagation of errors on the deltas line.
74  // See Network for a detailed discussion of the arguments.
75  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
76  NetworkScratch* scratch,
77  NetworkIO* back_deltas);
78 
79  // Splits the series after the given index, returning the two parts and
80  // deletes itself. The first part, up to network with index last_start, goes
81  // into start, and the rest goes into end.
82  void SplitAt(int last_start, Series** start, Series** end);
83 
84  // Appends the elements of the src series to this, removing from src and
85  // deleting it.
86  void AppendSeries(Network* src);
87 };
88 
89 } // namespace tesseract.
90 
91 #endif // TESSERACT_LSTM_SERIES_H_
virtual int XScaleFactor() const
Definition: series.cpp:79
virtual void CacheXScaleFactor(int factor)
Definition: series.cpp:88
virtual bool SetupNeedsBackprop(bool needs_backprop)
Definition: series.cpp:66
void SplitAt(int last_start, Series **start, Series **end)
Definition: series.cpp:147
virtual ~Series()
Definition: series.cpp:33
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: series.cpp:94
void AppendSeries(Network *src)
Definition: series.cpp:177
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: series.cpp:38
PointerVector< Network > stack_
Definition: plumbing.h:133
Definition: strngs.h:45
virtual STRING spec() const
Definition: series.h:37
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: series.cpp:116
Series(const STRING &name)
Definition: series.cpp:29
const STRING & name() const
Definition: network.h:138
virtual int InitWeights(float range, TRand *randomizer)
Definition: series.cpp:50