tesseract  4.00.00dev
tesseract::Network Class Reference

#include <network.h>

Inheritance diagram for tesseract::Network:
tesseract::Convolve tesseract::FullyConnected tesseract::Input tesseract::LSTM tesseract::Plumbing tesseract::Reconfig tesseract::Parallel tesseract::Reversed tesseract::Series tesseract::Maxpool

Public Member Functions

 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
const STRING & name () const
 
virtual STRING spec () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetEnableTraining (TrainingState state)
 
virtual void SetNetworkFlags (uinT32 flags)
 
virtual int InitWeights (float range, TRand *randomizer)
 
virtual void ConvertToInt ()
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
virtual void DebugWeights ()
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
virtual void Update (float learning_rate, float momentum, int num_samples)
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Static Public Member Functions

static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 

Protected Member Functions

double Random (double range)
 

Protected Attributes

NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
inT32 network_flags_
 
inT32 ni_
 
inT32 no_
 
inT32 num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 

Static Protected Attributes

static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 105 of file network.h.

Constructor & Destructor Documentation

◆ Network() [1/2]

tesseract::Network::Network ( )

Definition at line 76 of file network.cpp.

77  : type_(NT_NONE),
79  needs_to_backprop_(true),
80  network_flags_(0),
81  ni_(0),
82  no_(0),
83  num_weights_(0),
84  forward_win_(NULL),
85  backward_win_(NULL),
86  randomizer_(NULL) {}
bool needs_to_backprop_
Definition: network.h:287
ScrollView * backward_win_
Definition: network.h:296
TRand * randomizer_
Definition: network.h:297
inT32 network_flags_
Definition: network.h:288
TrainingState training_
Definition: network.h:286
ScrollView * forward_win_
Definition: network.h:295
NetworkType type_
Definition: network.h:285
inT32 num_weights_
Definition: network.h:291

◆ Network() [2/2]

tesseract::Network::Network ( NetworkType  type,
const STRING &  name,
int  ni,
int  no 
)

Definition at line 87 of file network.cpp.

88  : type_(type),
90  needs_to_backprop_(true),
91  network_flags_(0),
92  ni_(ni),
93  no_(no),
94  num_weights_(0),
95  name_(name),
96  forward_win_(NULL),
97  backward_win_(NULL),
98  randomizer_(NULL) {}
bool needs_to_backprop_
Definition: network.h:287
ScrollView * backward_win_
Definition: network.h:296
NetworkType type() const
Definition: network.h:112
TRand * randomizer_
Definition: network.h:297
inT32 network_flags_
Definition: network.h:288
TrainingState training_
Definition: network.h:286
ScrollView * forward_win_
Definition: network.h:295
NetworkType type_
Definition: network.h:285
inT32 num_weights_
Definition: network.h:291

◆ ~Network()

tesseract::Network::~Network ( )
virtual

Definition at line 100 of file network.cpp.

100  {
101 }

Member Function Documentation

◆ Backward()

virtual bool tesseract::Network::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
inline virtual

Reimplemented in tesseract::LSTM, tesseract::FullyConnected, tesseract::Reversed, tesseract::Input, tesseract::Series, tesseract::Parallel, tesseract::Reconfig, tesseract::Convolve, and tesseract::Maxpool.

Definition at line 259 of file network.h.

261  {
262  tprintf("Must override Network::Backward for type %d\n", type_);
263  return false;
264  }
#define tprintf(...)
Definition: tprintf.h:31
NetworkType type_
Definition: network.h:285

◆ CacheXScaleFactor()

virtual void tesseract::Network::CacheXScaleFactor ( int  factor)
inline virtual

Reimplemented in tesseract::Plumbing, tesseract::Input, and tesseract::Series.

Definition at line 201 of file network.h.

201 {}

◆ ClearWindow()

void tesseract::Network::ClearWindow ( bool  tess_coords,
const char *  window_name,
int  width,
int  height,
ScrollView **  window 
)
static

Definition at line 309 of file network.cpp.

310  {
311  if (*window == NULL) {
312  int min_size = MIN(width, height);
313  if (min_size < kMinWinSize) {
314  if (min_size < 1) min_size = 1;
315  width = width * kMinWinSize / min_size;
316  height = height * kMinWinSize / min_size;
317  }
318  width += kXWinFrameSize;
319  height += kYWinFrameSize;
320  if (width > kMaxWinSize) width = kMaxWinSize;
321  if (height > kMaxWinSize) height = kMaxWinSize;
322  *window = new ScrollView(window_name, 80, 100, width, height, width, height,
323  tess_coords);
324  tprintf("Created window %s of size %d, %d\n", window_name, width, height);
325  } else {
326  (*window)->Clear();
327  }
328 }
const int kXWinFrameSize
Definition: network.cpp:53
#define tprintf(...)
Definition: tprintf.h:31
const int kYWinFrameSize
Definition: network.cpp:54
const int kMinWinSize
Definition: network.cpp:50
#define MIN(x, y)
Definition: ndminx.h:28
const int kMaxWinSize
Definition: network.cpp:51

◆ ConvertToInt()

virtual void tesseract::Network::ConvertToInt ( )
inline virtual

Reimplemented in tesseract::LSTM, tesseract::FullyConnected, and tesseract::Plumbing.

Definition at line 177 of file network.h.

177 {}

◆ CountAlternators()

virtual void tesseract::Network::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
inline virtual

Reimplemented in tesseract::Plumbing, tesseract::FullyConnected, and tesseract::LSTM.

Definition at line 222 of file network.h.

223  {}

◆ CreateFromFile()

Network * tesseract::Network::CreateFromFile ( TFile *  fp)
static

Definition at line 203 of file network.cpp.

203  {
204  Network stub;
205  if (!stub.DeSerialize(fp)) return NULL;
206  Network* network = NULL;
207  switch (stub.type_) {
208  case NT_CONVOLVE:
209  network = new Convolve(stub.name_, stub.ni_, 0, 0);
210  break;
211  case NT_INPUT:
212  network = new Input(stub.name_, stub.ni_, stub.no_);
213  break;
214  case NT_LSTM:
215  case NT_LSTM_SOFTMAX:
217  case NT_LSTM_SUMMARY:
218  network =
219  new LSTM(stub.name_, stub.ni_, stub.no_, stub.no_, false, stub.type_);
220  break;
221  case NT_MAXPOOL:
222  network = new Maxpool(stub.name_, stub.ni_, 0, 0);
223  break;
224  // All variants of Parallel.
225  case NT_PARALLEL:
226  case NT_REPLICATED:
227  case NT_PAR_RL_LSTM:
228  case NT_PAR_UD_LSTM:
229  case NT_PAR_2D_LSTM:
230  network = new Parallel(stub.name_, stub.type_);
231  break;
232  case NT_RECONFIG:
233  network = new Reconfig(stub.name_, stub.ni_, 0, 0);
234  break;
235  // All variants of reversed.
236  case NT_XREVERSED:
237  case NT_YREVERSED:
238  case NT_XYTRANSPOSE:
239  network = new Reversed(stub.name_, stub.type_);
240  break;
241  case NT_SERIES:
242  network = new Series(stub.name_);
243  break;
244  case NT_TENSORFLOW:
245 #ifdef INCLUDE_TENSORFLOW
246  network = new TFNetwork(stub.name_);
247 #else
248  tprintf("TensorFlow not compiled in! -DINCLUDE_TENSORFLOW\n");
249  return NULL;
250 #endif
251  break;
252  // All variants of FullyConnected.
253  case NT_SOFTMAX:
254  case NT_SOFTMAX_NO_CTC:
255  case NT_RELU:
256  case NT_TANH:
257  case NT_LINEAR:
258  case NT_LOGISTIC:
259  case NT_POSCLIP:
260  case NT_SYMCLIP:
261  network = new FullyConnected(stub.name_, stub.ni_, stub.no_, stub.type_);
262  break;
263  default:
264  return NULL;
265  }
266  network->training_ = stub.training_;
267  network->needs_to_backprop_ = stub.needs_to_backprop_;
268  network->network_flags_ = stub.network_flags_;
269  network->num_weights_ = stub.num_weights_;
270  if (!network->DeSerialize(fp)) {
271  delete network;
272  return NULL;
273  }
274  return network;
275 }
#define tprintf(...)
Definition: tprintf.h:31

◆ DebugWeights()

virtual void tesseract::Network::DebugWeights ( )
inline virtual

Reimplemented in tesseract::Plumbing, tesseract::LSTM, and tesseract::FullyConnected.

Definition at line 204 of file network.h.

204  {
205  tprintf("Must override Network::DebugWeights for type %d\n", type_);
206  }
#define tprintf(...)
Definition: tprintf.h:31
NetworkType type_
Definition: network.h:285

◆ DeSerialize()

bool tesseract::Network::DeSerialize ( TFile *  fp)
virtual

Reimplemented in tesseract::Plumbing, tesseract::LSTM, tesseract::FullyConnected, tesseract::Reconfig, tesseract::Input, tesseract::Convolve, and tesseract::Maxpool.

Definition at line 172 of file network.cpp.

172  {
173  inT8 data = 0;
174  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
175  if (data == NT_NONE) {
176  STRING type_name;
177  if (!type_name.DeSerialize(fp)) return false;
178  for (data = 0; data < NT_COUNT && type_name != kTypeNames[data]; ++data) {
179  }
180  if (data == NT_COUNT) {
181  tprintf("Invalid network layer type:%s\n", type_name.string());
182  return false;
183  }
184  }
185  type_ = static_cast<NetworkType>(data);
186  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
188  if (fp->FRead(&data, sizeof(data), 1) != 1) return false;
189  needs_to_backprop_ = data != 0;
190  if (fp->FReadEndian(&network_flags_, sizeof(network_flags_), 1) != 1)
191  return false;
192  if (fp->FReadEndian(&ni_, sizeof(ni_), 1) != 1) return false;
193  if (fp->FReadEndian(&no_, sizeof(no_), 1) != 1) return false;
194  if (fp->FReadEndian(&num_weights_, sizeof(num_weights_), 1) != 1)
195  return false;
196  if (!name_.DeSerialize(fp)) return false;
197  return true;
198 }
bool needs_to_backprop_
Definition: network.h:287
bool DeSerialize(bool swap, FILE *fp)
Definition: strngs.cpp:163
#define tprintf(...)
Definition: tprintf.h:31
const char * string() const
Definition: strngs.cpp:198
inT32 network_flags_
Definition: network.h:288
TrainingState training_
Definition: network.h:286
Definition: strngs.h:45
static char const *const kTypeNames[NT_COUNT]
Definition: network.h:300
NetworkType
Definition: network.h:43
int8_t inT8
Definition: host.h:34
NetworkType type_
Definition: network.h:285
inT32 num_weights_
Definition: network.h:291

◆ DisplayBackward()

void tesseract::Network::DisplayBackward ( const NetworkIO &  matrix)

Definition at line 296 of file network.cpp.

296  {
297 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
298  Pix* image = matrix.ToPix();
299  STRING window_name = name_ + "-back";
300  ClearWindow(false, window_name.string(), pixGetWidth(image),
301  pixGetHeight(image), &backward_win_);
302  DisplayImage(image, backward_win_);
303  backward_win_->Update();
304 #endif // GRAPHICS_DISABLED
305 }
ScrollView * backward_win_
Definition: network.h:296
const char * string() const
Definition: strngs.cpp:198
Definition: strngs.h:45
static void Update()
Definition: scrollview.cpp:715
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:332

◆ DisplayForward()

void tesseract::Network::DisplayForward ( const NetworkIO &  matrix)

Definition at line 285 of file network.cpp.

285  {
286 #ifndef GRAPHICS_DISABLED // do nothing if there's no graphics
287  Pix* image = matrix.ToPix();
288  ClearWindow(false, name_.string(), pixGetWidth(image),
289  pixGetHeight(image), &forward_win_);
290  DisplayImage(image, forward_win_);
291  forward_win_->Update();
292 #endif // GRAPHICS_DISABLED
293 }
const char * string() const
Definition: strngs.cpp:198
static void Update()
Definition: scrollview.cpp:715
static void ClearWindow(bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
Definition: network.cpp:309
ScrollView * forward_win_
Definition: network.h:295
static int DisplayImage(Pix *pix, ScrollView *window)
Definition: network.cpp:332

◆ DisplayImage()

int tesseract::Network::DisplayImage ( Pix *  pix,
ScrollView *  window 
)
static

Definition at line 332 of file network.cpp.

332  {
333  int height = pixGetHeight(pix);
334  window->Image(pix, 0, 0);
335  pixDestroy(&pix);
336  return height;
337 }
void Image(struct Pix *image, int x_pos, int y_pos)
Definition: scrollview.cpp:773

◆ Forward()

virtual void tesseract::Network::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
inline virtual

Reimplemented in tesseract::LSTM, tesseract::FullyConnected, tesseract::Reversed, tesseract::Input, tesseract::Series, tesseract::Parallel, tesseract::Reconfig, tesseract::Convolve, and tesseract::Maxpool.

Definition at line 248 of file network.h.

250  {
251  tprintf("Must override Network::Forward for type %d\n", type_);
252  }
#define tprintf(...)
Definition: tprintf.h:31
NetworkType type_
Definition: network.h:285

◆ InitWeights()

int tesseract::Network::InitWeights ( float  range,
TRand *  randomizer 
)
virtual

Reimplemented in tesseract::LSTM, tesseract::FullyConnected, tesseract::Plumbing, and tesseract::Series.

Definition at line 132 of file network.cpp.

132  {
133  randomizer_ = randomizer;
134  return 0;
135 }
TRand * randomizer_
Definition: network.h:297

◆ InputShape()

virtual StaticShape tesseract::Network::InputShape ( ) const
inline virtual

Reimplemented in tesseract::Input, and tesseract::Plumbing.

Definition at line 127 of file network.h.

127  {
128  StaticShape result;
129  return result;
130  }

◆ IsPlumbingType()

virtual bool tesseract::Network::IsPlumbingType ( ) const
inline virtual

Reimplemented in tesseract::Plumbing.

Definition at line 152 of file network.h.

152 { return false; }

◆ IsTraining()

bool tesseract::Network::IsTraining ( ) const
inline

Definition at line 115 of file network.h.

115 { return training_ == TS_ENABLED; }
TrainingState training_
Definition: network.h:286

◆ name()

const STRING& tesseract::Network::name ( ) const
inline

Definition at line 138 of file network.h.

138  {
139  return name_;
140  }

◆ needs_to_backprop()

bool tesseract::Network::needs_to_backprop ( ) const
inline

Definition at line 116 of file network.h.

116  {
117  return needs_to_backprop_;
118  }
bool needs_to_backprop_
Definition: network.h:287

◆ num_weights()

int tesseract::Network::num_weights ( ) const
inline

Definition at line 119 of file network.h.

119 { return num_weights_; }
inT32 num_weights_
Definition: network.h:291

◆ NumInputs()

int tesseract::Network::NumInputs ( ) const
inline

Definition at line 120 of file network.h.

120  {
121  return ni_;
122  }

◆ NumOutputs()

int tesseract::Network::NumOutputs ( ) const
inline

Definition at line 123 of file network.h.

123  {
124  return no_;
125  }

◆ OutputShape()

virtual StaticShape tesseract::Network::OutputShape ( const StaticShape &  input_shape) const
inline virtual

Reimplemented in tesseract::LSTM, tesseract::Input, tesseract::Reconfig, tesseract::FullyConnected, tesseract::Parallel, tesseract::Reversed, and tesseract::Series.

Definition at line 133 of file network.h.

133  {
134  StaticShape result(input_shape);
135  result.set_depth(no_);
136  return result;
137  }

◆ Random()

double tesseract::Network::Random ( double  range)
protected

Definition at line 278 of file network.cpp.

278  {
279  ASSERT_HOST(randomizer_ != NULL);
280  return randomizer_->SignedRand(range);
281 }
TRand * randomizer_
Definition: network.h:297
#define ASSERT_HOST(x)
Definition: errcode.h:84
double SignedRand(double range)
Definition: helpers.h:60

◆ Serialize()

bool tesseract::Network::Serialize ( TFile *  fp) const
virtual

Reimplemented in tesseract::Plumbing, tesseract::LSTM, tesseract::FullyConnected, tesseract::Reconfig, tesseract::Input, and tesseract::Convolve.

Definition at line 153 of file network.cpp.

153  {
154  inT8 data = NT_NONE;
155  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
156  STRING type_name = kTypeNames[type_];
157  if (!type_name.Serialize(fp)) return false;
158  data = training_;
159  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
160  data = needs_to_backprop_;
161  if (fp->FWrite(&data, sizeof(data), 1) != 1) return false;
162  if (fp->FWrite(&network_flags_, sizeof(network_flags_), 1) != 1) return false;
163  if (fp->FWrite(&ni_, sizeof(ni_), 1) != 1) return false;
164  if (fp->FWrite(&no_, sizeof(no_), 1) != 1) return false;
165  if (fp->FWrite(&num_weights_, sizeof(num_weights_), 1) != 1) return false;
166  if (!name_.Serialize(fp)) return false;
167  return true;
168 }
bool needs_to_backprop_
Definition: network.h:287
inT32 network_flags_
Definition: network.h:288
TrainingState training_
Definition: network.h:286
Definition: strngs.h:45
bool Serialize(FILE *fp) const
Definition: strngs.cpp:148
static char const *const kTypeNames[NT_COUNT]
Definition: network.h:300
int8_t inT8
Definition: host.h:34
NetworkType type_
Definition: network.h:285
inT32 num_weights_
Definition: network.h:291

◆ SetEnableTraining()

void tesseract::Network::SetEnableTraining ( TrainingState  state)
virtual

Reimplemented in tesseract::LSTM, tesseract::FullyConnected, and tesseract::Plumbing.

Definition at line 112 of file network.cpp.

112  {
113  if (state == TS_RE_ENABLE) {
114  // Enable only from temp disabled.
115  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
116  } else if (state == TS_TEMP_DISABLE) {
117  // Temp disable only from enabled.
118  if (training_ == TS_ENABLED) training_ = state;
119  } else {
120  training_ = state;
121  }
122 }
TrainingState training_
Definition: network.h:286

◆ SetNetworkFlags()

void tesseract::Network::SetNetworkFlags ( uinT32  flags)
virtual

Reimplemented in tesseract::Plumbing.

Definition at line 126 of file network.cpp.

126  {
127  network_flags_ = flags;
128 }
inT32 network_flags_
Definition: network.h:288

◆ SetRandomizer()

void tesseract::Network::SetRandomizer ( TRand *  randomizer)
virtual

Reimplemented in tesseract::Plumbing.

Definition at line 140 of file network.cpp.

140  {
141  randomizer_ = randomizer;
142 }
TRand * randomizer_
Definition: network.h:297

◆ SetupNeedsBackprop()

bool tesseract::Network::SetupNeedsBackprop ( bool  needs_backprop)
virtual

Reimplemented in tesseract::Plumbing, and tesseract::Series.

Definition at line 147 of file network.cpp.

147  {
148  needs_to_backprop_ = needs_backprop;
149  return needs_backprop || num_weights_ > 0;
150 }
bool needs_to_backprop_
Definition: network.h:287
inT32 num_weights_
Definition: network.h:291

◆ spec()

virtual STRING tesseract::Network::spec ( ) const
inline virtual

◆ TestFlag()

bool tesseract::Network::TestFlag ( NetworkFlags  flag) const
inline

Definition at line 144 of file network.h.

144  {
145  return (network_flags_ & flag) != 0;
146  }
inT32 network_flags_
Definition: network.h:288

◆ type()

NetworkType tesseract::Network::type ( ) const
inline

Definition at line 112 of file network.h.

112  {
113  return type_;
114  }
NetworkType type_
Definition: network.h:285

◆ Update()

virtual void tesseract::Network::Update ( float  learning_rate,
float  momentum,
int  num_samples 
)
inline virtual

Reimplemented in tesseract::Plumbing, tesseract::FullyConnected, and tesseract::LSTM.

Definition at line 218 of file network.h.

218 {}

◆ XScaleFactor()

virtual int tesseract::Network::XScaleFactor ( ) const
inline virtual

Reimplemented in tesseract::Plumbing, tesseract::Input, tesseract::Series, and tesseract::Reconfig.

Definition at line 195 of file network.h.

195  {
196  return 1;
197  }

Member Data Documentation

◆ backward_win_

ScrollView* tesseract::Network::backward_win_
protected

Definition at line 296 of file network.h.

◆ forward_win_

ScrollView* tesseract::Network::forward_win_
protected

Definition at line 295 of file network.h.

◆ kTypeNames

char const *const tesseract::Network::kTypeNames
static protected
Initial value:
= {
"Invalid", "Input",
"Convolve", "Maxpool",
"Parallel", "Replicated",
"ParBidiLSTM", "DepParUDLSTM",
"Par2dLSTM", "Series",
"Reconfig", "RTLReversed",
"TTBReversed", "XYTranspose",
"LSTM", "SummLSTM",
"Logistic", "LinLogistic",
"LinTanh", "Tanh",
"Relu", "Linear",
"Softmax", "SoftmaxNoCTC",
"LSTMSoftmax", "LSTMBinarySoftmax",
"TensorFlow",
}

Definition at line 300 of file network.h.

◆ name_

STRING tesseract::Network::name_
protected

Definition at line 292 of file network.h.

◆ needs_to_backprop_

bool tesseract::Network::needs_to_backprop_
protected

Definition at line 287 of file network.h.

◆ network_flags_

inT32 tesseract::Network::network_flags_
protected

Definition at line 288 of file network.h.

◆ ni_

inT32 tesseract::Network::ni_
protected

Definition at line 289 of file network.h.

◆ no_

inT32 tesseract::Network::no_
protected

Definition at line 290 of file network.h.

◆ num_weights_

inT32 tesseract::Network::num_weights_
protected

Definition at line 291 of file network.h.

◆ randomizer_

TRand* tesseract::Network::randomizer_
protected

Definition at line 297 of file network.h.

◆ training_

TrainingState tesseract::Network::training_
protected

Definition at line 286 of file network.h.

◆ type_

NetworkType tesseract::Network::type_
protected

Definition at line 285 of file network.h.


The documentation for this class was generated from the following files: