tesseract  4.00.00dev
input.h
Go to the documentation of this file.
1 // File: input.h
3 // Description: Input layer class for neural network implementations.
4 // Author: Ray Smith
5 // Created: Thu Mar 13 08:56:26 PDT 2014
6 //
7 // (C) Copyright 2014, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #ifndef TESSERACT_LSTM_INPUT_H_
20 #define TESSERACT_LSTM_INPUT_H_
21 
22 #include "network.h"
23 
24 class ScrollView;
25 
26 namespace tesseract {
27 
// Input layer of a network. Derives from Network and stores a StaticShape
// (shape_) that spec(), InputShape() and OutputShape() all report, plus a
// cached x scale factor (cached_x_scale_). Everything not defined inline here
// is implemented out of line.
28 class Input : public Network {
29  public:
    // Two construction paths: a name plus ni/no (presumably input/output
    // counts -- confirm against the definition), or a name plus an explicit
    // StaticShape.
30  Input(const STRING& name, int ni, int no);
31  Input(const STRING& name, const StaticShape& shape);
32  virtual ~Input();
33 
    // Returns a spec string assembled from shape_ by four add_str_int calls:
    // batch, height, width and depth, in that order. The first call passes ""
    // as the string argument and the remaining three pass ",".
34  virtual STRING spec() const {
35  STRING spec;
36  spec.add_str_int("", shape_.batch());
37  spec.add_str_int(",", shape_.height());
38  spec.add_str_int(",", shape_.width());
39  spec.add_str_int(",", shape_.depth());
40  return spec;
41  }
42 
43  // Returns the required shape input to the network.
    // For this layer that is simply the stored shape_.
44  virtual StaticShape InputShape() const { return shape_; }
45  // Returns the shape output from the network given an input shape (which may
46  // be partially unknown i.e. zero).
    // This implementation ignores input_shape and always returns shape_.
47  virtual StaticShape OutputShape(const StaticShape& input_shape) const {
48  return shape_;
49  }
50  // Writes to the given file. Returns false in case of error.
51  // Should be overridden by subclasses, but called by their Serialize.
52  virtual bool Serialize(TFile* fp) const;
53  // Reads from the given file. Returns false in case of error.
54  virtual bool DeSerialize(TFile* fp);
55 
56  // Returns an integer reduction factor that the network applies to the
57  // time sequence. Assumes that any 2-d is already eliminated. Used for
58  // scaling bounding boxes of truth data.
59  // WARNING: if GlobalMinimax is used to vary the scale, this will return
60  // the last used scale factor. Call it before any forward, and it will return
61  // the minimum scale factor of the paths through the GlobalMinimax.
62  virtual int XScaleFactor() const;
63 
64  // Provides the (minimum) x scale factor to the network (of interest only to
65  // input units) so they can determine how to scale bounding boxes.
    // NOTE(review): presumably stores factor in cached_x_scale_ -- confirm
    // against the out-of-line definition.
66  virtual void CacheXScaleFactor(int factor);
67 
68  // Runs forward propagation of activations on the input line.
69  // See Network for a detailed discussion of the arguments.
70  virtual void Forward(bool debug, const NetworkIO& input,
71  const TransposedArray* input_transpose,
72  NetworkScratch* scratch, NetworkIO* output);
73 
74  // Runs backward propagation of errors on the deltas line.
75  // See Network for a detailed discussion of the arguments.
76  virtual bool Backward(bool debug, const NetworkIO& fwd_deltas,
77  NetworkScratch* scratch,
78  NetworkIO* back_deltas);
79  // Creates and returns a Pix of appropriate size for the network from the
80  // image_data. If non-null, *image_scale returns the image scale factor used.
81  // Returns nullptr on error.
    // NOTE(review): ownership of the returned Pix is not stated in this
    // header -- confirm whether the caller must release it.
82  /* static */
83  static Pix* PrepareLSTMInputs(const ImageData& image_data,
84  const Network* network, int min_width,
85  TRand* randomizer, float* image_scale);
86  // Converts the given pix to a NetworkIO of height and depth appropriate to
87  // the given StaticShape:
88  // If depth == 3, convert to 24 bit color, otherwise normalized grey.
89  // Scale to target height, if the shape's height is > 1, or its depth if the
90  // height == 1. If height == 0 then no scaling.
91  // NOTE: It isn't safe for multiple threads to call this on the same pix.
    // The result is written through the input out-parameter; nothing is
    // returned.
92  static void PreparePixInput(const StaticShape& shape, const Pix* pix,
93  TRand* randomizer, NetworkIO* input);
94 
95  private:
96  // Input shape determines how images are dealt with.
    // Reported unchanged by InputShape() and OutputShape(), and used to build
    // spec().
97  StaticShape shape_;
98  // Cached total network x scale factor for scaling bounding boxes.
    // No in-class initializer; the initial value depends on the constructors,
    // which are not visible in this header.
99  int cached_x_scale_;
100 };
101 
102 } // namespace tesseract.
103 
104 #endif // TESSERACT_LSTM_INPUT_H_
105 
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
Definition: input.cpp:78
virtual StaticShape OutputShape(const StaticShape &input_shape) const
Definition: input.h:47
static void PreparePixInput(const StaticShape &shape, const Pix *pix, TRand *randomizer, NetworkIO *input)
Definition: input.cpp:117
static Pix * PrepareLSTMInputs(const ImageData &image_data, const Network *network, int min_width, TRand *randomizer, float *image_scale)
Definition: input.cpp:89
virtual bool DeSerialize(TFile *fp)
Definition: input.cpp:51
virtual int XScaleFactor() const
Definition: input.cpp:58
Input(const STRING &name, int ni, int no)
Definition: input.cpp:31
virtual void CacheXScaleFactor(int factor)
Definition: input.cpp:64
virtual ~Input()
Definition: input.cpp:40
STRING
Definition: strngs.h:45
virtual StaticShape InputShape() const
Definition: input.h:44
virtual bool Serialize(TFile *fp) const
Definition: input.cpp:44
const STRING & name() const
Definition: network.h:138
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
Definition: input.cpp:70
virtual STRING spec() const
Definition: input.h:34