tesseract  4.00.00dev
tesseract::FullyConnected Class Reference

#include <fullyconnected.h>

Inheritance diagram for tesseract::FullyConnected:
tesseract::Network

Public Member Functions

 FullyConnected (const STRING &name, int ni, int no, NetworkType type)
 
virtual ~FullyConnected ()
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
virtual STRING spec () const
 
void ChangeType (NetworkType type)
 
virtual void SetEnableTraining (TrainingState state)
 
virtual int InitWeights (float range, TRand *randomizer)
 
virtual void ConvertToInt ()
 
virtual void DebugWeights ()
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
 
void SetupForward (const NetworkIO &input, const TransposedArray *input_transpose)
 
void ForwardTimeStep (const double *d_input, const inT8 *i_input, int t, double *output_line)
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
 
void BackwardTimeStep (const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
 
void FinishBackward (const TransposedArray &errors_t)
 
virtual void Update (float learning_rate, float momentum, int num_samples)
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uinT32 flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Protected Attributes

WeightMatrix weights_
 
TransposedArray source_t_
 
const TransposedArray * external_source_
 
NetworkIO acts_
 
bool int_mode_
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
inT32 network_flags_
 
inT32 ni_
 
inT32 no_
 
inT32 num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 28 of file fullyconnected.h.

Constructor & Destructor Documentation

◆ FullyConnected()

tesseract::FullyConnected::FullyConnected ( const STRING & name,
int  ni,
int  no,
NetworkType  type 
)

Definition at line 35 of file fullyconnected.cpp.

37  : Network(type, name, ni, no), external_source_(NULL), int_mode_(false) {
38 }
NetworkType type() const
Definition: network.h:112
const TransposedArray * external_source_

◆ ~FullyConnected()

tesseract::FullyConnected::~FullyConnected ( )
virtual

Definition at line 40 of file fullyconnected.cpp.

40  {
41 }

Member Function Documentation

◆ Backward()

bool tesseract::FullyConnected::Backward ( bool  debug,
const NetworkIO & fwd_deltas,
NetworkScratch * scratch,
NetworkIO * back_deltas 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 204 of file fullyconnected.cpp.

206  {
207  if (debug) DisplayBackward(fwd_deltas);
208  back_deltas->Resize(fwd_deltas, ni_);
209  GenericVector<NetworkScratch::FloatVec> errors;
210  errors.init_to_size(kNumThreads, NetworkScratch::FloatVec());
211  for (int i = 0; i < errors.size(); ++i) errors[i].Init(no_, scratch);
212  GenericVector<NetworkScratch::FloatVec> temp_backprops;
213  if (needs_to_backprop_) {
214  temp_backprops.init_to_size(kNumThreads, NetworkScratch::FloatVec());
215  for (int i = 0; i < kNumThreads; ++i) temp_backprops[i].Init(ni_, scratch);
216  }
217  int width = fwd_deltas.Width();
218  NetworkScratch::GradientStore errors_t;
219  errors_t.Init(no_, width, scratch);
220 #ifdef _OPENMP
221 #pragma omp parallel for num_threads(kNumThreads)
222  for (int t = 0; t < width; ++t) {
223  int thread_id = omp_get_thread_num();
224 #else
225  for (int t = 0; t < width; ++t) {
226  int thread_id = 0;
227 #endif
228  double* backprop = NULL;
229  if (needs_to_backprop_) backprop = temp_backprops[thread_id];
230  double* curr_errors = errors[thread_id];
231  BackwardTimeStep(fwd_deltas, t, curr_errors, errors_t.get(), backprop);
232  if (backprop != NULL) {
233  back_deltas->WriteTimeStep(t, backprop);
234  }
235  }
236  FinishBackward(*errors_t.get());
237  if (needs_to_backprop_) {
238  back_deltas->ZeroInvalidElements();
239  back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
240 #if DEBUG_DETAIL > 0
241  tprintf("F Backprop:%s\n", name_.string());
242  back_deltas->Print(10);
243 #endif
244  return true;
245  }
246  return false; // No point going further back.
247 }
bool needs_to_backprop_
Definition: network.h:287
void init_to_size(int size, T t)
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
int size() const
Definition: genericvector.h:72
const int kNumThreads
void DisplayBackward(const NetworkIO &matrix)
Definition: network.cpp:296
void FinishBackward(const TransposedArray &errors_t)
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)

◆ BackwardTimeStep()

void tesseract::FullyConnected::BackwardTimeStep ( const NetworkIO & fwd_deltas,
int  t,
double *  curr_errors,
TransposedArray * errors_t,
double *  backprop 
)

Definition at line 249 of file fullyconnected.cpp.

252  {
253  if (type_ == NT_TANH)
254  acts_.FuncMultiply<GPrime>(fwd_deltas, t, curr_errors);
255  else if (type_ == NT_LOGISTIC)
256  acts_.FuncMultiply<FPrime>(fwd_deltas, t, curr_errors);
257  else if (type_ == NT_POSCLIP)
258  acts_.FuncMultiply<ClipFPrime>(fwd_deltas, t, curr_errors);
259  else if (type_ == NT_SYMCLIP)
260  acts_.FuncMultiply<ClipGPrime>(fwd_deltas, t, curr_errors);
261  else if (type_ == NT_RELU)
262  acts_.FuncMultiply<ReluPrime>(fwd_deltas, t, curr_errors);
263  else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC ||
264  type_ == NT_LINEAR)
265  fwd_deltas.ReadTimeStep(t, curr_errors); // fwd_deltas are the errors.
266  else
267  ASSERT_HOST("Invalid fully-connected type!" == NULL);
268  // Generate backprop only if needed by the lower layer.
269  if (backprop != NULL) weights_.VectorDotMatrix(curr_errors, backprop);
270  errors_t->WriteStrided(t, curr_errors);
271 }
void VectorDotMatrix(const double *u, double *v) const
#define ASSERT_HOST(x)
Definition: errcode.h:84
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
Definition: networkio.h:259
NetworkType type_
Definition: network.h:285

◆ ChangeType()

void tesseract::FullyConnected::ChangeType ( NetworkType  type)
inline

Definition at line 60 of file fullyconnected.h.

60  {
61  type_ = type;
62  }
NetworkType type() const
Definition: network.h:112
NetworkType type_
Definition: network.h:285

◆ ConvertToInt()

void tesseract::FullyConnected::ConvertToInt ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 84 of file fullyconnected.cpp.

84  {
85  weights_.ConvertToInt();
86 }

◆ CountAlternators()

void tesseract::FullyConnected::CountAlternators ( const Network & other,
double *  same,
double *  changed 
) const
virtual

Reimplemented from tesseract::Network.

Definition at line 291 of file fullyconnected.cpp.

292  {
293  ASSERT_HOST(other.type() == type_);
294  const FullyConnected* fc = static_cast<const FullyConnected*>(&other);
295  weights_.CountAlternators(fc->weights_, same, changed);
296 }
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
#define ASSERT_HOST(x)
Definition: errcode.h:84
NetworkType type_
Definition: network.h:285

◆ DebugWeights()

void tesseract::FullyConnected::DebugWeights ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 89 of file fullyconnected.cpp.

89  {
90  weights_.Debug2D(name_.string());
91 }
void Debug2D(const char *msg)
const char * string() const
Definition: strngs.cpp:198

◆ DeSerialize()

bool tesseract::FullyConnected::DeSerialize ( TFile * fp)
virtual

Reimplemented from tesseract::Network.

Definition at line 101 of file fullyconnected.cpp.

101  {
102  return weights_.DeSerialize(IsTraining(), fp);
103 }
bool IsTraining() const
Definition: network.h:115
bool DeSerialize(bool training, TFile *fp)

◆ FinishBackward()

void tesseract::FullyConnected::FinishBackward ( const TransposedArray & errors_t)

Definition at line 273 of file fullyconnected.cpp.

273  {
274  if (external_source_ == NULL)
275  weights_.SumOuterTransposed(errors_t, source_t_, true);
276  else
277  weights_.SumOuterTransposed(errors_t, *external_source_, true);
278 }
const TransposedArray * external_source_
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)

◆ Forward()

void tesseract::FullyConnected::Forward ( bool  debug,
const NetworkIO & input,
const TransposedArray * input_transpose,
NetworkScratch * scratch,
NetworkIO * output 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 107 of file fullyconnected.cpp.

109  {
110  int width = input.Width();
111  if (type_ == NT_SOFTMAX)
112  output->ResizeFloat(input, no_);
113  else
114  output->Resize(input, no_);
115  SetupForward(input, input_transpose);
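116  GenericVector<NetworkScratch::FloatVec> temp_lines;
117  temp_lines.init_to_size(kNumThreads, NetworkScratch::FloatVec());
118  GenericVector<NetworkScratch::FloatVec> curr_input;
119  curr_input.init_to_size(kNumThreads, NetworkScratch::FloatVec());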
120  for (int i = 0; i < temp_lines.size(); ++i) {
121  temp_lines[i].Init(no_, scratch);
122  curr_input[i].Init(ni_, scratch);
123  }
124 #ifdef _OPENMP
125 #pragma omp parallel for num_threads(kNumThreads)
126  for (int t = 0; t < width; ++t) {
127  // Thread-local pointer to temporary storage.
128  int thread_id = omp_get_thread_num();
129 #else
130  for (int t = 0; t < width; ++t) {
131  // Thread-local pointer to temporary storage.
132  int thread_id = 0;
133 #endif
134  double* temp_line = temp_lines[thread_id];
135  const double* d_input = NULL;
136  const inT8* i_input = NULL;
137  if (input.int_mode()) {
138  i_input = input.i(t);
139  } else {
140  input.ReadTimeStep(t, curr_input[thread_id]);
141  d_input = curr_input[thread_id];
142  }
143  ForwardTimeStep(d_input, i_input, t, temp_line);
144  output->WriteTimeStep(t, temp_line);
145  if (IsTraining() && type_ != NT_SOFTMAX) {
146  acts_.CopyTimeStepFrom(t, *output, t);
147  }
148  }
149  // Zero all the elements that are in the padding around images that allows
150  // multiple different-sized images to exist in a single array.
151  // acts_ is only used if this is not a softmax op.
152  if (IsTraining() && type_ != NT_SOFTMAX) {
153  acts_.ZeroInvalidElements();
154  }
155  output->ZeroInvalidElements();
156 #if DEBUG_DETAIL > 0
157  tprintf("F Output:%s\n", name_.string());
158  output->Print(10);
159 #endif
160  if (debug) DisplayForward(*output);
161 }
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
Definition: networkio.cpp:383
void ZeroInvalidElements()
Definition: networkio.cpp:88
void init_to_size(int size, T t)
void DisplayForward(const NetworkIO &matrix)
Definition: network.cpp:285
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
bool IsTraining() const
Definition: network.h:115
int size() const
Definition: genericvector.h:72
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
int8_t inT8
Definition: host.h:34
const int kNumThreads
NetworkType type_
Definition: network.h:285
void ForwardTimeStep(const double *d_input, const inT8 *i_input, int t, double *output_line)

◆ ForwardTimeStep()

void tesseract::FullyConnected::ForwardTimeStep ( const double *  d_input,
const inT8 * i_input,
int  t,
double *  output_line 
)

Definition at line 176 of file fullyconnected.cpp.

177  {
178  // input is copied to source_ line-by-line for cache coherency.
179  if (IsTraining() && external_source_ == NULL && d_input != NULL)
180  source_t_.WriteStrided(t, d_input);
181  if (d_input != NULL)
182  weights_.MatrixDotVector(d_input, output_line);
183  else
184  weights_.MatrixDotVector(i_input, output_line);
185  if (type_ == NT_TANH) {
186  FuncInplace<GFunc>(no_, output_line);
187  } else if (type_ == NT_LOGISTIC) {
188  FuncInplace<FFunc>(no_, output_line);
189  } else if (type_ == NT_POSCLIP) {
190  FuncInplace<ClipFFunc>(no_, output_line);
191  } else if (type_ == NT_SYMCLIP) {
192  FuncInplace<ClipGFunc>(no_, output_line);
193  } else if (type_ == NT_RELU) {
194  FuncInplace<Relu>(no_, output_line);
195  } else if (type_ == NT_SOFTMAX || type_ == NT_SOFTMAX_NO_CTC) {
196  SoftmaxInPlace(no_, output_line);
197  } else if (type_ != NT_LINEAR) {
198  ASSERT_HOST("Invalid fully-connected type!" == NULL);
199  }
200 }
void MatrixDotVector(const double *u, double *v) const
void WriteStrided(int t, const float *data)
Definition: weightmatrix.h:37
bool IsTraining() const
Definition: network.h:115
#define ASSERT_HOST(x)
Definition: errcode.h:84
void SoftmaxInPlace(int n, T *inout)
Definition: functions.h:163
NetworkType type_
Definition: network.h:285
const TransposedArray * external_source_

◆ InitWeights()

int tesseract::FullyConnected::InitWeights ( float  range,
TRand * randomizer 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 76 of file fullyconnected.cpp.

76  {
77  Network::SetRandomizer(randomizer);
78  num_weights_ = weights_.InitWeightsFloat(no_, ni_ + 1, TestFlag(NF_ADA_GRAD),
79  range, randomizer);
80  return num_weights_;
81 }
virtual void SetRandomizer(TRand *randomizer)
Definition: network.cpp:140
bool TestFlag(NetworkFlags flag) const
Definition: network.h:144
int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range, TRand *randomizer)
inT32 num_weights_
Definition: network.h:291

◆ OutputShape()

StaticShape tesseract::FullyConnected::OutputShape ( const StaticShape & input_shape) const
virtual

Reimplemented from tesseract::Network.

Definition at line 45 of file fullyconnected.cpp.

45  {
46  LossType loss_type = LT_NONE;
47  if (type_ == NT_SOFTMAX)
48  loss_type = LT_CTC;
49  else if (type_ == NT_SOFTMAX_NO_CTC)
50  loss_type = LT_SOFTMAX;
51  else if (type_ == NT_LOGISTIC)
52  loss_type = LT_LOGISTIC;
53  StaticShape result(input_shape);
54  result.set_depth(no_);
55  result.set_loss_type(loss_type);
56  return result;
57 }
NetworkType type_
Definition: network.h:285

◆ Serialize()

bool tesseract::FullyConnected::Serialize ( TFile * fp) const
virtual

Reimplemented from tesseract::Network.

Definition at line 94 of file fullyconnected.cpp.

94  {
95  if (!Network::Serialize(fp)) return false;
96  if (!weights_.Serialize(IsTraining(), fp)) return false;
97  return true;
98 }
bool IsTraining() const
Definition: network.h:115
bool Serialize(bool training, TFile *fp) const
virtual bool Serialize(TFile *fp) const
Definition: network.cpp:153

◆ SetEnableTraining()

void tesseract::FullyConnected::SetEnableTraining ( TrainingState  state)
virtual

Reimplemented from tesseract::Network.

Definition at line 60 of file fullyconnected.cpp.

60  {
61  if (state == TS_RE_ENABLE) {
62  // Enable only from temp disabled.
63  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
64  } else if (state == TS_TEMP_DISABLE) {
65  // Temp disable only from enabled.
66  if (training_ == TS_ENABLED) training_ = state;
67  } else {
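68  if (state == TS_ENABLED && training_ != TS_ENABLED)
69  [line not captured in this extraction; presumably prepares weights_ for backward passes (e.g. a weights_.InitBackward(...) call) — confirm against fullyconnected.cpp]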
70  training_ = state;
71  }
72 }
TrainingState training_
Definition: network.h:286

◆ SetupForward()

void tesseract::FullyConnected::SetupForward ( const NetworkIO & input,
const TransposedArray * input_transpose 
)

Definition at line 164 of file fullyconnected.cpp.

165  {
166  // Softmax output is always float, so save the input type.
167  int_mode_ = input.int_mode();
168  if (IsTraining()) {
169  acts_.Resize(input, no_);
170  // Source_ is a transposed copy of input. It isn't needed if provided.
171  external_source_ = input_transpose;
172  if (external_source_ == NULL) source_t_.ResizeNoInit(ni_, input.Width());
173  }
174 }
void Resize(const NetworkIO &src, int num_features)
Definition: networkio.h:45
bool IsTraining() const
Definition: network.h:115
void ResizeNoInit(int size1, int size2)
Definition: matrix.h:86
const TransposedArray * external_source_

◆ spec()

virtual STRING tesseract::FullyConnected::spec ( ) const
inline virtual

Reimplemented from tesseract::Network.

Definition at line 37 of file fullyconnected.h.

37  {
38  STRING spec;
39  if (type_ == NT_TANH)
40  spec.add_str_int("Ft", no_);
41  else if (type_ == NT_LOGISTIC)
42  spec.add_str_int("Fs", no_);
43  else if (type_ == NT_RELU)
44  spec.add_str_int("Fr", no_);
45  else if (type_ == NT_LINEAR)
46  spec.add_str_int("Fl", no_);
47  else if (type_ == NT_POSCLIP)
48  spec.add_str_int("Fp", no_);
49  else if (type_ == NT_SYMCLIP)
50  spec.add_str_int("Fs", no_);
51  else if (type_ == NT_SOFTMAX)
52  spec.add_str_int("Fc", no_);
53  else
54  spec.add_str_int("Fm", no_);
55  return spec;
56  }
void add_str_int(const char *str, int number)
Definition: strngs.cpp:381
virtual STRING spec() const
Definition: strngs.h:45
NetworkType type_
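Definition: network.h:285

Note: both NT_LOGISTIC (line 42 of the listing) and NT_SYMCLIP (line 50) emit the prefix "Fs", so the spec() string does not distinguish these two network types.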

◆ Update()

void tesseract::FullyConnected::Update ( float  learning_rate,
float  momentum,
int  num_samples 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 283 of file fullyconnected.cpp.

284  {
285  weights_.Update(learning_rate, momentum, num_samples);
286 }
void Update(double learning_rate, double momentum, int num_samples)

Member Data Documentation

◆ acts_

NetworkIO tesseract::FullyConnected::acts_
protected

Definition at line 123 of file fullyconnected.h.

◆ external_source_

const TransposedArray* tesseract::FullyConnected::external_source_
protected

Definition at line 121 of file fullyconnected.h.

◆ int_mode_

bool tesseract::FullyConnected::int_mode_
protected

Definition at line 126 of file fullyconnected.h.

◆ source_t_

TransposedArray tesseract::FullyConnected::source_t_
protected

Definition at line 118 of file fullyconnected.h.

◆ weights_

WeightMatrix tesseract::FullyConnected::weights_
protected

Definition at line 116 of file fullyconnected.h.


The documentation for this class was generated from the following files: