// File:        fullyconnected.cpp
// Description: Simple feed-forward layer with various non-linearities.
// Author:      Ray Smith
// Created:     Wed Feb 26 14:49:15 PST 2014
//
// (C) Copyright 2014, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "fullyconnected.h"

#ifdef _OPENMP
#include <omp.h>
#endif
#include <stdio.h>
#include <stdlib.h>

#include "functions.h"
#include "networkscratch.h"

// Number of threads to use for parallel calculation of Forward and Backward.
const int kNumThreads = 4;

namespace tesseract {

FullyConnected::FullyConnected(const STRING& name, int ni, int no,
                               NetworkType type)
    : Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
}

FullyConnected::~FullyConnected() {
}

// Returns the shape output from the network given an input shape (which may
// be partially unknown, i.e. zero).
StaticShape FullyConnected::OutputShape(const StaticShape& input_shape) const {
  LossType loss_type = LT_NONE;
  if (type_ == NT_SOFTMAX)
    loss_type = LT_CTC;
  else if (type_ == NT_SOFTMAX_NO_CTC)
    loss_type = LT_SOFTMAX;
  else if (type_ == NT_LOGISTIC)
    loss_type = LT_LOGISTIC;
  StaticShape result(input_shape);
  result.set_depth(no_);
  result.set_loss_type(loss_type);
  return result;
}
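
// Worked example (illustrative, not in the original file): an NT_SOFTMAX
// layer with no_ = 100 maps an input shape of [batch, height, width, depth]
// = [1, 1, W, ni] to [1, 1, W, 100] with loss type LT_CTC; only the depth
// and the loss type change, and the spatial dimensions pass through.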

// Suspends/Enables training by setting the training_ flag.
void FullyConnected::SetEnableTraining(TrainingState state) {
  if (state == TS_RE_ENABLE) {
    // Enable only from temp disabled.
    if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
  } else if (state == TS_TEMP_DISABLE) {
    // Temp disable only from enabled.
    if (training_ == TS_ENABLED) training_ = state;
  } else {
    if (state == TS_ENABLED && training_ != TS_ENABLED)
      weights_.InitBackward(TestFlag(NF_ADA_GRAD));
    training_ = state;
  }
}

// Sets up the network for training. Initializes weights using weights of
// scale `range` picked according to the random number generator `randomizer`.
int FullyConnected::InitWeights(float range, TRand* randomizer) {
  Network::SetRandomizer(randomizer);
  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADA_GRAD),
                                           range, randomizer);
  return num_weights_;
}
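
// Note the ni_ + 1 above: the weight matrix carries one extra input column
// holding the bias, so each forward step effectively computes W * [x; 1].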

// Converts a float network to an int network.
void FullyConnected::ConvertToInt() {
  weights_.ConvertToInt();
}

// Provides debug output on the weights.
void FullyConnected::DebugWeights() {
  weights_.Debug2D(name_.string());
}

// Writes to the given file. Returns false in case of error.
bool FullyConnected::Serialize(TFile* fp) const {
  if (!Network::Serialize(fp)) return false;
  if (!weights_.Serialize(IsTraining(), fp)) return false;
  return true;
}

// Reads from the given file. Returns false in case of error.
bool FullyConnected::DeSerialize(TFile* fp) {
  return weights_.DeSerialize(IsTraining(), fp);
}

// Runs forward propagation of activations on the input line.
// See NetworkCpp for a detailed discussion of the arguments.
void FullyConnected::Forward(bool debug, const NetworkIO& input,
                             const TransposedArray* input_transpose,
                             NetworkScratch* scratch, NetworkIO* output) {
  int width = input.Width();
  if (type_ == NT_SOFTMAX)
    output->ResizeFloat(input, no_);
  else
    output->Resize(input, no_);
  SetupForward(input, input_transpose);
  GenericVector<NetworkScratch::FloatVec> temp_lines;
  temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  GenericVector<NetworkScratch::FloatVec> curr_input;
  curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  for (int i = 0; i < temp_lines.size(); ++i) {
    temp_lines[i].Init(no_, scratch);
    curr_input[i].Init(ni_, scratch);
  }
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    // Thread-local pointer to temporary storage.
    int thread_id = 0;
#endif
    double* temp_line = temp_lines[thread_id];
    const double* d_input = NULL;
    const inT8* i_input = NULL;
    if (input.int_mode()) {
      i_input = input.i(t);
    } else {
      input.ReadTimeStep(t, curr_input[thread_id]);
      d_input = curr_input[thread_id];
    }
    ForwardTimeStep(d_input, i_input, t, temp_line);
    output->WriteTimeStep(t, temp_line);
    if (IsTraining() && type_ != NT_SOFTMAX) {
      acts_.CopyTimeStepFrom(t, *output, t);
    }
  }
  // Zero all the elements that are in the padding around images that allows
  // multiple different-sized images to exist in a single array.
  // acts_ is only used if this is not a softmax op.
  if (IsTraining() && type_ != NT_SOFTMAX) {
    acts_.ZeroInvalidElements();
  }
  output->ZeroInvalidElements();
#if DEBUG_DETAIL > 0
  tprintf("F Output:%s\n", name_.string());
  output->Print(10);
#endif
  if (debug) DisplayForward(*output);
}
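
// Parallelization note: timesteps are independent in a fully-connected
// layer, so the loop over t above is split across kNumThreads OpenMP threads
// when available, each thread writing through its own scratch vectors
// (temp_lines/curr_input); without OpenMP the same body runs serially with
// thread_id fixed at 0.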

// Components of Forward so FullyConnected can be reused inside LSTM.
void FullyConnected::SetupForward(const NetworkIO& input,
                                  const TransposedArray* input_transpose) {
  // Softmax output is always float, so save the input type.
  int_mode_ = input.int_mode();
  if (IsTraining()) {
    acts_.Resize(input, no_);
    // Source_ is a transposed copy of input. It isn't needed if provided.
    external_source_ = input_transpose;
    if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width());
  }
}
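
// Keeping a transposed copy of the input (or borrowing the caller's via
// external_source_) lets FinishBackward form the weight gradient as a single
// outer product over all timesteps instead of re-transposing every pass.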

void FullyConnected::ForwardTimeStep(const double* d_input, const inT8* i_input,
                                     int t, double* output_line) {
  // input is copied to source_ line-by-line for cache coherency.
  if (IsTraining() && external_source_ == NULL && d_input != NULL)
    source_t_.WriteStrided(t, d_input);
  if (d_input != NULL)
    weights_.MatrixDotVector(d_input, output_line);
  else
    weights_.MatrixDotVector(i_input, output_line);
  if (type_ == NT_TANH) {
    FuncInplace<GFunc>(no_, output_line);
  } else if (type_ == NT_LOGISTIC) {
    FuncInplace<FFunc>(no_, output_line);
  } else if (type_ == NT_POSCLIP) {
    FuncInplace<ClipFFunc>(no_, output_line);
  } else if (type_ == NT_SYMCLIP) {
    FuncInplace<ClipGFunc>(no_, output_line);
  } else if (type_ == NT_RELU) {
    FuncInplace<Relu>(no_, output_line);
  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
    SoftmaxInPlace(no_, output_line);
  } else if (type_ != NT_LINEAR) {
    ASSERT_HOST("Invalid fully-connected type!" == NULL);
  }
}
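
// In equations, for input x of length ni_ at one timestep this computes
//   a = W * [x; 1]   (MatrixDotVector, bias folded into the last column)
//   y = f(a)         (f selected by type_: tanh, logistic, clip, relu, ...)
// with the softmax types instead producing y_i = exp(a_i) / sum_j exp(a_j).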

// Runs backward propagation of errors on the deltas line.
// See NetworkCpp for a detailed discussion of the arguments.
bool FullyConnected::Backward(bool debug, const NetworkIO& fwd_deltas,
                              NetworkScratch* scratch,
                              NetworkIO* back_deltas) {
  if (debug) DisplayBackward(fwd_deltas);
  back_deltas->Resize(fwd_deltas, ni_);
  GenericVector<NetworkScratch::FloatVec> errors;
  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
  for (int i = 0; i < errors.size(); ++i) errors[i].Init(no_, scratch);
  GenericVector<NetworkScratch::FloatVec> temp_backprops;
  if (needs_to_backprop_) {
    temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
    for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
  }
  int width = fwd_deltas.Width();
  NetworkScratch::GradientStore errors_t;
  errors_t.Init(no_, width, scratch);
#ifdef _OPENMP
#pragma omp parallel for num_threads(kNumThreads)
  for (int t = 0; t < width; ++t) {
    int thread_id = omp_get_thread_num();
#else
  for (int t = 0; t < width; ++t) {
    int thread_id = 0;
#endif
    double* backprop = NULL;
    if (needs_to_backprop_) backprop = temp_backprops[thread_id];
    double* curr_errors = errors[thread_id];
    BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
    if (backprop != NULL) {
      back_deltas->WriteTimeStep(t, backprop);
    }
  }
  FinishBackward(*errors_t.get());
  if (needs_to_backprop_) {
    back_deltas->ZeroInvalidElements();
    back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
#if DEBUG_DETAIL > 0
    tprintf("F Backprop:%s\n", name_.string());
    back_deltas->Print(10);
#endif
    return true;
  }
  return false;  // No point going further back.
}

void FullyConnected::BackwardTimeStep(const NetworkIO& fwd_deltas, int t,
                                      double* curr_errors,
                                      TransposedArray* errors_t,
                                      double* backprop) {
  if (type_ == NT_TANH)
    acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_LOGISTIC)
    acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_POSCLIP)
    acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_SYMCLIP)
    acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_RELU)
    acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
           type_ == NT_LINEAR)
    fwd_deltas.ReadTimeStep(t, curr_errors);  // fwd_deltas are the errors.
  else
    ASSERT_HOST("Invalid fully-connected type!" == NULL);
  // Generate backprop only if needed by the lower layer.
  if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop);
  errors_t->WriteStrided(t, curr_errors);
}
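
// In equations, with incoming deltas d at timestep t and saved activations
// y = acts_:
//   e = f'(y) .* d      (elementwise; softmax/linear types take d directly)
//   backprop = W^T * e  (VectorDotMatrix; the bias column has no input line
//                        to receive a delta)
// e is also written transposed into errors_t for the weight gradient.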

void FullyConnected::FinishBackward(const TransposedArray& errors_t) {
  if (external_source_ == NULL)
    weights_.SumOuterTransposed(errors_t, source_t_, true);
  else
    weights_.SumOuterTransposed(errors_t, *external_source_, true);
}
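
// The weight gradient accumulated here is dW = sum over t of
// e_t * [x_t; 1]^T, formed in one SumOuterTransposed call over the
// transposed stores; the trailing `true` requests the parallel path.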

// Updates the weights using the given learning rate and momentum.
// num_samples is the quotient to be used in the adagrad computation iff
// use_ada_grad_ is true.
void FullyConnected::Update(float learning_rate, float momentum,
                            int num_samples) {
  weights_.Update(learning_rate, momentum, num_samples);
}

// Sums the products of weight updates in *this and other, splitting into
// positive (same direction) in *same and negative (different direction) in
// *changed.
void FullyConnected::CountAlternators(const Network& other, double* same,
                                      double* changed) const {
  ASSERT_HOST(other.type() == type_);
  const FullyConnected* fc = static_cast<const FullyConnected*>(&other);
  weights_.CountAlternators(fc->weights_, same, changed);
}

}  // namespace tesseract.
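
// Illustrative usage sketch (an assumption-laden outline, not part of the
// original file). The FullyConnected calls match the signatures above; the
// TRand/NetworkIO setup is plausible but unverified for this revision:
//
//   FullyConnected fc("fc1", /*ni=*/32, /*no=*/100, NT_SOFTMAX);
//   TRand randomizer;
//   randomizer.set_seed(42);
//   fc.InitWeights(0.1f, &randomizer);  // weights drawn with scale 0.1.
//   NetworkScratch scratch;
//   NetworkIO input, output;
//   input.Resize2d(false, /*width=*/20, /*num_features=*/32);
//   fc.Forward(/*debug=*/false, input, NULL, &scratch, &output);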