tesseract  4.00.00dev
tesseract::LSTM Class Reference

#include <lstm.h>

Inheritance diagram for tesseract::LSTM:
tesseract::Network

Public Types

enum  WeightType {
  CI, GI, GF1, GO,
  GFS, WT_COUNT
}
 

Public Member Functions

 LSTM (const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
 
virtual ~LSTM ()
 
virtual StaticShape OutputShape (const StaticShape &input_shape) const
 
virtual STRING spec () const
 
virtual void SetEnableTraining (TrainingState state)
 
virtual int InitWeights (float range, TRand *randomizer)
 
virtual void ConvertToInt ()
 
virtual void DebugWeights ()
 
virtual bool Serialize (TFile *fp) const
 
virtual bool DeSerialize (TFile *fp)
 
virtual void Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
 
virtual bool Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
 
virtual void Update (float learning_rate, float momentum, int num_samples)
 
virtual void CountAlternators (const Network &other, double *same, double *changed) const
 
void PrintW ()
 
void PrintDW ()
 
bool Is2D () const
 
- Public Member Functions inherited from tesseract::Network
 Network ()
 
 Network (NetworkType type, const STRING &name, int ni, int no)
 
virtual ~Network ()
 
NetworkType type () const
 
bool IsTraining () const
 
bool needs_to_backprop () const
 
int num_weights () const
 
int NumInputs () const
 
int NumOutputs () const
 
virtual StaticShape InputShape () const
 
const STRING & name () const
 
bool TestFlag (NetworkFlags flag) const
 
virtual bool IsPlumbingType () const
 
virtual void SetNetworkFlags (uinT32 flags)
 
virtual void SetRandomizer (TRand *randomizer)
 
virtual bool SetupNeedsBackprop (bool needs_backprop)
 
virtual int XScaleFactor () const
 
virtual void CacheXScaleFactor (int factor)
 
void DisplayForward (const NetworkIO &matrix)
 
void DisplayBackward (const NetworkIO &matrix)
 

Additional Inherited Members

- Static Public Member Functions inherited from tesseract::Network
static Network * CreateFromFile (TFile *fp)
 
static void ClearWindow (bool tess_coords, const char *window_name, int width, int height, ScrollView **window)
 
static int DisplayImage (Pix *pix, ScrollView *window)
 
- Protected Member Functions inherited from tesseract::Network
double Random (double range)
 
- Protected Attributes inherited from tesseract::Network
NetworkType type_
 
TrainingState training_
 
bool needs_to_backprop_
 
inT32 network_flags_
 
inT32 ni_
 
inT32 no_
 
inT32 num_weights_
 
STRING name_
 
ScrollView * forward_win_
 
ScrollView * backward_win_
 
TRand * randomizer_
 
- Static Protected Attributes inherited from tesseract::Network
static char const *const kTypeNames [NT_COUNT]
 

Detailed Description

Definition at line 28 of file lstm.h.

Member Enumeration Documentation

◆ WeightType

Enumerator
CI 
GI 
GF1 
GO 
GFS 
WT_COUNT 

Definition at line 33 of file lstm.h.

33  enum WeightType {
34  CI, // Cell Inputs.
35  GI, // Gate at the input.
36  GF1, // Forget gate at the memory (1-d or looking back 1 timestep).
37  GO, // Gate at the output.
38  GFS, // Forget gate at the memory, looking back in the other dimension.
39 
40  WT_COUNT // Number of WeightTypes.
41  };
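
Every weight-bearing operation in this class iterates the gate matrices by WeightType and skips GFS when the network is not two-dimensional, since those weights only exist in 2-D mode. The pattern, as it appears in ConvertToInt(), Update() and the other methods below (gate_weights_ is a private array indexed by WeightType):

for (int w = 0; w < WT_COUNT; ++w) {
  if (w == GFS && !Is2D()) continue;  // GFS weights exist only in 2-D mode.
  // ... operate on gate_weights_[w] ...
}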

Constructor & Destructor Documentation

◆ LSTM()

tesseract::LSTM::LSTM ( const STRING &  name,
int  num_inputs,
int  num_states,
int  num_outputs,
bool  two_dimensional,
NetworkType  type 
)

Definition at line 70 of file lstm.cpp.

70 LSTM::LSTM(const STRING& name, int ni, int ns, int no, bool two_dimensional,
71  NetworkType type)
72  : Network(type, name, ni, no),
73  na_(ni + ns),
74  ns_(ns),
75  nf_(0),
76  is_2d_(two_dimensional),
77  softmax_(NULL),
78  input_width_(0) {
79  if (two_dimensional) na_ += ns_;
80  if (type_ == NT_LSTM || type_ == NT_LSTM_SUMMARY) {
81  nf_ = 0;
82  // networkbuilder ensures this is always true.
83  ASSERT_HOST(no == ns);
84  } else if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
85  nf_ = type_ == NT_LSTM_SOFTMAX ? no_ : IntCastRounded(ceil(log2(no_)));
86  softmax_ = new FullyConnected("LSTM Softmax", ns_, no_, NT_SOFTMAX);
87  } else {
88  tprintf("%d is invalid type of LSTM!\n", type);
89  ASSERT_HOST(false);
90  }
91  na_ += nf_;
92 }
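A hypothetical construction sketch (the name and sizes are illustrative only). Note that for NT_LSTM the ASSERT_HOST above requires num_outputs == num_states:

TRand randomizer;
randomizer.set_seed(0x12345678);
// 1-D LSTM with 20 inputs and 96 states/outputs, no integrated softmax.
LSTM* lstm = new LSTM("lstm1", /*num_inputs=*/20, /*num_states=*/96,
                      /*num_outputs=*/96, /*two_dimensional=*/false, NT_LSTM);
lstm->InitWeights(0.1f, &randomizer);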

◆ ~LSTM()

tesseract::LSTM::~LSTM ( )
virtual

Definition at line 94 of file lstm.cpp.

94 LSTM::~LSTM() { delete softmax_; }

Member Function Documentation

◆ Backward()

bool tesseract::LSTM::Backward ( bool  debug,
const NetworkIO &  fwd_deltas,
NetworkScratch *  scratch,
NetworkIO *  back_deltas 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 400 of file lstm.cpp.

400 bool LSTM::Backward(bool debug, const NetworkIO& fwd_deltas,
401  NetworkScratch* scratch,
402  NetworkIO* back_deltas) {
403  if (debug) DisplayBackward(fwd_deltas);
404  back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_, ni_);
405  // ======Scratch space.======
406  // Output errors from deltas with recurrence from sourceerr.
407  NetworkScratch::FloatVec outputerr;
408  outputerr.Init(ns_, scratch);
409  // Recurrent error in the state/source.
410  NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
411  curr_stateerr.Init(ns_, scratch);
412  curr_sourceerr.Init(na_, scratch);
413  ZeroVector<double>(ns_, curr_stateerr);
414  ZeroVector<double>(na_, curr_sourceerr);
415  // Errors in the gates.
416  NetworkScratch::FloatVec gate_errors[WT_COUNT];
417  for (int g = 0; g < WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
418  // Rotating buffers of width buf_width allow storage of the recurrent time-
419  // steps used only for true 2-D. Stores one full strip of the major direction.
420  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
421  GenericVector<NetworkScratch::FloatVec> stateerr, sourceerr;
422  if (Is2D()) {
423  stateerr.init_to_size(buf_width, NetworkScratch::FloatVec());
424  sourceerr.init_to_size(buf_width, NetworkScratch::FloatVec());
425  for (int t = 0; t < buf_width; ++t) {
426  stateerr[t].Init(ns_, scratch);
427  sourceerr[t].Init(na_, scratch);
428  ZeroVector<double>(ns_, stateerr[t]);
429  ZeroVector<double>(na_, sourceerr[t]);
430  }
431  }
432  // Parallel-generated sourceerr from each of the gates.
433  NetworkScratch::FloatVec sourceerr_temps[WT_COUNT];
434  for (int w = 0; w < WT_COUNT; ++w)
435  sourceerr_temps[w].Init(na_, scratch);
436  int width = input_width_;
437  // Transposed gate errors stored over all timesteps for sum outer.
438  NetworkScratch::GradientStore gate_errors_t[WT_COUNT];
439  for (int w = 0; w < WT_COUNT; ++w) {
440  gate_errors_t[w].Init(ns_, width, scratch);
441  }
442  // Used only if softmax_ != NULL.
443  NetworkScratch::FloatVec softmax_errors;
444  NetworkScratch::GradientStore softmax_errors_t;
445  if (softmax_ != NULL) {
446  softmax_errors.Init(no_, scratch);
447  softmax_errors_t.Init(no_, width, scratch);
448  }
449  double state_clip = Is2D() ? 9.0 : 4.0;
450 #if DEBUG_DETAIL > 1
451  tprintf("fwd_deltas:%s\n", name_.string());
452  fwd_deltas.Print(10);
453 #endif
454  StrideMap::Index dest_index(input_map_);
455  dest_index.InitToLast();
456  // Used only by NT_LSTM_SUMMARY.
457  StrideMap::Index src_index(fwd_deltas.stride_map());
458  src_index.InitToLast();
459  do {
460  int t = dest_index.t();
461  bool at_last_x = dest_index.IsLast(FD_WIDTH);
462  // up_pos is the 2-D back step, down_pos is the 2-D fwd step, and are only
463  // valid if >= 0, which is true if 2d and not on the top/bottom.
464  int up_pos = -1;
465  int down_pos = -1;
466  if (Is2D()) {
467  if (dest_index.index(FD_HEIGHT) > 0) {
468  StrideMap::Index up_index(dest_index);
469  if (up_index.AddOffset(-1, FD_HEIGHT)) up_pos = up_index.t();
470  }
471  if (!dest_index.IsLast(FD_HEIGHT)) {
472  StrideMap::Index down_index(dest_index);
473  if (down_index.AddOffset(1, FD_HEIGHT)) down_pos = down_index.t();
474  }
475  }
476  // Index of the 2-D revolving buffers (sourceerr, stateerr).
477  int mod_t = Modulo(t, buf_width); // Current timestep.
478  // Zero the state in the major direction only at the end of every row.
479  if (at_last_x) {
480  ZeroVector<double>(na_, curr_sourceerr);
481  ZeroVector<double>(ns_, curr_stateerr);
482  }
483  // Setup the outputerr.
484  if (type_ == NT_LSTM_SUMMARY) {
485  if (dest_index.IsLast(FD_WIDTH)) {
486  fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
487  src_index.Decrement();
488  } else {
489  ZeroVector<double>(ns_, outputerr);
490  }
491  } else if (softmax_ == NULL) {
492  fwd_deltas.ReadTimeStep(t, outputerr);
493  } else {
494  softmax_->BackwardTimeStep(fwd_deltas, t, softmax_errors,
495  softmax_errors_t.get(), outputerr);
496  }
497  if (!at_last_x)
498  AccumulateVector(ns_, curr_sourceerr + ni_ + nf_, outputerr);
499  if (down_pos >= 0)
500  AccumulateVector(ns_, sourceerr[mod_t] + ni_ + nf_ + ns_, outputerr);
501  // Apply the 1-d forget gates.
502  if (!at_last_x) {
503  const float* next_node_gf1 = node_values_[GF1].f(t + 1);
504  for (int i = 0; i < ns_; ++i) {
505  curr_stateerr[i] *= next_node_gf1[i];
506  }
507  }
508  if (Is2D() && t + 1 < width) {
509  for (int i = 0; i < ns_; ++i) {
510  if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
511  }
512  if (down_pos >= 0) {
513  const float* right_node_gfs = node_values_[GFS].f(down_pos);
514  const double* right_stateerr = stateerr[mod_t];
515  for (int i = 0; i < ns_; ++i) {
516  if (which_fg_[down_pos][i] == 2) {
517  curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
518  }
519  }
520  }
521  }
522  state_.FuncMultiply3Add<HPrime>(node_values_[GO], t, outputerr,
523  curr_stateerr);
524  // Clip stateerr_ to a sane range.
525  ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
526 #if DEBUG_DETAIL > 1
527  if (t + 10 > width) {
528  tprintf("t=%d, stateerr=", t);
529  for (int i = 0; i < ns_; ++i)
530  tprintf(" %g,%g,%g", curr_stateerr[i], outputerr[i],
531  curr_sourceerr[ni_ + nf_ + i]);
532  tprintf("\n");
533  }
534 #endif
535  // Matrix multiply to get the source errors.
536  PARALLEL_IF_OPENMP(GFS)
537 
538  // Cell inputs.
539  node_values_[CI].FuncMultiply3<GPrime>(t, node_values_[GI], t,
540  curr_stateerr, gate_errors[CI]);
541  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[CI].get());
542  gate_weights_[CI].VectorDotMatrix(gate_errors[CI], sourceerr_temps[CI]);
543  gate_errors_t[CI].get()->WriteStrided(t, gate_errors[CI]);
544 
545  SECTION_IF_OPENMP
546  // Input Gates.
547  node_values_[GI].FuncMultiply3<FPrime>(t, node_values_[CI], t,
548  curr_stateerr, gate_errors[GI]);
549  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GI].get());
550  gate_weights_[GI].VectorDotMatrix(gate_errors[GI], sourceerr_temps[GI]);
551  gate_errors_t[GI].get()->WriteStrided(t, gate_errors[GI]);
552 
553  SECTION_IF_OPENMP
554  // 1-D forget Gates.
555  if (t > 0) {
556  node_values_[GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
557  gate_errors[GF1]);
558  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GF1].get());
559  gate_weights_[GF1].VectorDotMatrix(gate_errors[GF1],
560  sourceerr_temps[GF1]);
561  } else {
562  memset(gate_errors[GF1], 0, ns_ * sizeof(gate_errors[GF1][0]));
563  memset(sourceerr_temps[GF1], 0, na_ * sizeof(*sourceerr_temps[GF1]));
564  }
565  gate_errors_t[GF1].get()->WriteStrided(t, gate_errors[GF1]);
566 
567  // 2-D forget Gates.
568  if (up_pos >= 0) {
569  node_values_[GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
570  gate_errors[GFS]);
571  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GFS].get());
572  gate_weights_[GFS].VectorDotMatrix(gate_errors[GFS],
573  sourceerr_temps[GFS]);
574  } else {
575  memset(gate_errors[GFS], 0, ns_ * sizeof(gate_errors[GFS][0]));
576  memset(sourceerr_temps[GFS], 0, na_ * sizeof(*sourceerr_temps[GFS]));
577  }
578  if (Is2D()) gate_errors_t[GFS].get()->WriteStrided(t, gate_errors[GFS]);
579 
580  SECTION_IF_OPENMP
581  // Output gates.
582  state_.Func2Multiply3<HFunc, FPrime>(node_values_[GO], t, outputerr,
583  gate_errors[GO]);
584  ClipVector(ns_, -kErrClip, kErrClip, gate_errors[GO].get());
585  gate_weights_[GO].VectorDotMatrix(gate_errors[GO], sourceerr_temps[GO]);
586  gate_errors_t[GO].get()->WriteStrided(t, gate_errors[GO]);
587  END_PARALLEL_IF_OPENMP
588 
589  SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
590  sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
591  curr_sourceerr);
592  back_deltas->WriteTimeStep(t, curr_sourceerr);
593  // Save states for use by the 2nd dimension only if needed.
594  if (Is2D()) {
595  CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
596  CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
597  }
598  } while (dest_index.Decrement());
599 #if DEBUG_DETAIL > 2
600  for (int w = 0; w < WT_COUNT; ++w) {
601  tprintf("%s gate errors[%d]\n", name_.string(), w);
602  gate_errors_t[w].get()->PrintUnTransposed(10);
603  }
604 #endif
605  // Transposed source_ used to speed-up SumOuter.
606  NetworkScratch::GradientStore source_t, state_t;
607  source_t.Init(na_, width, scratch);
608  source_.Transpose(source_t.get());
609  state_t.Init(ns_, width, scratch);
610  state_.Transpose(state_t.get());
611 #ifdef _OPENMP
612 #pragma omp parallel for num_threads(GFS) if (!Is2D())
613 #endif
614  for (int w = 0; w < WT_COUNT; ++w) {
615  if (w == GFS && !Is2D()) continue;
616  gate_weights_[w].SumOuterTransposed(*gate_errors_t[w], *source_t, false);
617  }
618  if (softmax_ != NULL) {
619  softmax_->FinishBackward(*softmax_errors_t);
620  }
621  if (needs_to_backprop_) {
622  // Normalize the inputerr in back_deltas.
623  back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
624  return true;
625  }
626  return false;
627 }
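Both the state error and the per-gate errors are clipped before use: gate errors to +/-kErrClip, and curr_stateerr to state_clip, which is 9.0 for 2-D networks and 4.0 otherwise. A standalone sketch of the element-wise operation ClipVector performs here (illustrative, not the tesseract API):

#include <algorithm>

// Clip each element of vec into [lower, upper], as ClipVector does above.
static void ClipToRange(int n, double lower, double upper, double* vec) {
  for (int i = 0; i < n; ++i)
    vec[i] = std::min(upper, std::max(lower, vec[i]));
}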

◆ ConvertToInt()

void tesseract::LSTM::ConvertToInt ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 144 of file lstm.cpp.

144 void LSTM::ConvertToInt() {
145  for (int w = 0; w < WT_COUNT; ++w) {
146  if (w == GFS && !Is2D()) continue;
147  gate_weights_[w].ConvertToInt();
148  }
149  if (softmax_ != NULL) {
150  softmax_->ConvertToInt();
151  }
152 }
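ConvertToInt() delegates to WeightMatrix::ConvertToInt for each active gate and to the attached softmax. As a rough, self-contained illustration of per-row 8-bit quantization of this general kind (a sketch only; the exact scale handling in WeightMatrix is not shown here, and this helper is hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Quantize one weight row to int8 with a per-row scale (illustrative).
static void QuantizeRow(const std::vector<double>& w,
                        std::vector<int8_t>* q, double* scale) {
  double max_abs = 0.0;
  for (double v : w) max_abs = std::max(max_abs, std::fabs(v));
  *scale = max_abs > 0.0 ? max_abs / 127.0 : 1.0;
  q->resize(w.size());
  for (size_t i = 0; i < w.size(); ++i)
    (*q)[i] = static_cast<int8_t>(std::round(w[i] / *scale));
}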

◆ CountAlternators()

void tesseract::LSTM::CountAlternators ( const Network &  other,
double *  same,
double *  changed 
) const
virtual

Reimplemented from tesseract::Network.

Definition at line 651 of file lstm.cpp.

651 void LSTM::CountAlternators(const Network& other, double* same,
652  double* changed) const {
653  ASSERT_HOST(other.type() == type_);
654  const LSTM* lstm = static_cast<const LSTM*>(&other);
655  for (int w = 0; w < WT_COUNT; ++w) {
656  if (w == GFS && !Is2D()) continue;
657  gate_weights_[w].CountAlternators(lstm->gate_weights_[w], same, changed);
658  }
659  if (softmax_ != NULL) {
660  softmax_->CountAlternators(*lstm->softmax_, same, changed);
661  }
662 }

◆ DebugWeights()

void tesseract::LSTM::DebugWeights ( )
virtual

Reimplemented from tesseract::Network.

Definition at line 155 of file lstm.cpp.

155 void LSTM::DebugWeights() {
156  for (int w = 0; w < WT_COUNT; ++w) {
157  if (w == GFS && !Is2D()) continue;
158  STRING msg = name_;
159  msg.add_str_int(" Gate weights ", w);
160  gate_weights_[w].Debug2D(msg.string());
161  }
162  if (softmax_ != NULL) {
163  softmax_->DebugWeights();
164  }
165 }

◆ DeSerialize()

bool tesseract::LSTM::DeSerialize ( TFile *  fp)
virtual

Reimplemented from tesseract::Network.

Definition at line 181 of file lstm.cpp.

181 bool LSTM::DeSerialize(TFile* fp) {
182  if (fp->FReadEndian(&na_, sizeof(na_), 1) != 1) return false;
183  if (type_ == NT_LSTM_SOFTMAX) {
184  nf_ = no_;
185  } else if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
186  nf_ = IntCastRounded(ceil(log2(no_)));
187  } else {
188  nf_ = 0;
189  }
190  is_2d_ = false;
191  for (int w = 0; w < WT_COUNT; ++w) {
192  if (w == GFS && !Is2D()) continue;
193  if (!gate_weights_[w].DeSerialize(IsTraining(), fp)) return false;
194  if (w == CI) {
195  ns_ = gate_weights_[CI].NumOutputs();
196  is_2d_ = na_ - nf_ == ni_ + 2 * ns_;
197  }
198  }
199  delete softmax_;
200  if (type_ == NT_LSTM_SOFTMAX || type_ == NT_LSTM_SOFTMAX_ENCODED) {
201  softmax_ = static_cast<FullyConnected*>(Network::CreateFromFile(fp));
202  if (softmax_ == nullptr) return false;
203  } else {
204  softmax_ = nullptr;
205  }
206  return true;
207 }
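The is_2d_ test follows from the constructor's bookkeeping: na_ = ni_ + nf_ + ns_ for a 1-D LSTM, and na_ = ni_ + nf_ + 2 * ns_ for a 2-D one (the extra ns_ is the recurrent input from the second dimension). Once ns_ has been recovered from the CI weight matrix, na_ - nf_ == ni_ + 2 * ns_ therefore holds exactly when the serialized network was two-dimensional.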

◆ Forward()

void tesseract::LSTM::Forward ( bool  debug,
const NetworkIO &  input,
const TransposedArray *  input_transpose,
NetworkScratch *  scratch,
NetworkIO *  output 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 211 of file lstm.cpp.

211 void LSTM::Forward(bool debug, const NetworkIO& input,
212  const TransposedArray* input_transpose,
213  NetworkScratch* scratch, NetworkIO* output) {
214  input_map_ = input.stride_map();
215  input_width_ = input.Width();
216  if (softmax_ != NULL)
217  output->ResizeFloat(input, no_);
218  else if (type_ == NT_LSTM_SUMMARY)
219  output->ResizeXTo1(input, no_);
220  else
221  output->Resize(input, no_);
222  ResizeForward(input);
223  // Temporary storage of forward computation for each gate.
224  NetworkScratch::FloatVec temp_lines[WT_COUNT];
225  for (int i = 0; i < WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
226  // Single timestep buffers for the current/recurrent output and state.
227  NetworkScratch::FloatVec curr_state, curr_output;
228  curr_state.Init(ns_, scratch);
229  ZeroVector<double>(ns_, curr_state);
230  curr_output.Init(ns_, scratch);
231  ZeroVector<double>(ns_, curr_output);
232  // Rotating buffers of width buf_width allow storage of the state and output
233  // for the other dimension, used only when working in true 2D mode. The width
234  // is enough to hold an entire strip of the major direction.
235  int buf_width = Is2D() ? input_map_.Size(FD_WIDTH) : 1;
236  GenericVector<NetworkScratch::FloatVec> states, outputs;
237  if (Is2D()) {
238  states.init_to_size(buf_width, NetworkScratch::FloatVec());
239  outputs.init_to_size(buf_width, NetworkScratch::FloatVec());
240  for (int i = 0; i < buf_width; ++i) {
241  states[i].Init(ns_, scratch);
242  ZeroVector<double>(ns_, states[i]);
243  outputs[i].Init(ns_, scratch);
244  ZeroVector<double>(ns_, outputs[i]);
245  }
246  }
247  // Used only if a softmax LSTM.
248  NetworkScratch::FloatVec softmax_output;
249  NetworkScratch::IO int_output;
250  if (softmax_ != NULL) {
251  softmax_output.Init(no_, scratch);
252  ZeroVector<double>(no_, softmax_output);
253  if (input.int_mode()) int_output.Resize2d(true, 1, ns_, scratch);
254  softmax_->SetupForward(input, NULL);
255  }
256  NetworkScratch::FloatVec curr_input;
257  curr_input.Init(na_, scratch);
258  StrideMap::Index src_index(input_map_);
259  // Used only by NT_LSTM_SUMMARY.
260  StrideMap::Index dest_index(output->stride_map());
261  do {
262  int t = src_index.t();
263  // True if there is a valid old state for the 2nd dimension.
264  bool valid_2d = Is2D();
265  if (valid_2d) {
266  StrideMap::Index dim_index(src_index);
267  if (!dim_index.AddOffset(-1, FD_HEIGHT)) valid_2d = false;
268  }
269  // Index of the 2-D revolving buffers (outputs, states).
270  int mod_t = Modulo(t, buf_width); // Current timestep.
271  // Setup the padded input in source.
272  source_.CopyTimeStepGeneral(t, 0, ni_, input, t, 0);
273  if (softmax_ != NULL) {
274  source_.WriteTimeStepPart(t, ni_, nf_, softmax_output);
275  }
276  source_.WriteTimeStepPart(t, ni_ + nf_, ns_, curr_output);
277  if (Is2D())
278  source_.WriteTimeStepPart(t, ni_ + nf_ + ns_, ns_, outputs[mod_t]);
279  if (!source_.int_mode()) source_.ReadTimeStep(t, curr_input);
280  // Matrix multiply the inputs with the source.
281  PARALLEL_IF_OPENMP(GFS)
282  // It looks inefficient to create the threads on each t iteration, but the
283  // alternative of putting the parallel outside the t loop, a single around
284  // the t-loop and then tasks in place of the sections is a *lot* slower.
285  // Cell inputs.
286  if (source_.int_mode())
287  gate_weights_[CI].MatrixDotVector(source_.i(t), temp_lines[CI]);
288  else
289  gate_weights_[CI].MatrixDotVector(curr_input, temp_lines[CI]);
290  FuncInplace<GFunc>(ns_, temp_lines[CI]);
291 
292  SECTION_IF_OPENMP
293  // Input Gates.
294  if (source_.int_mode())
295  gate_weights_[GI].MatrixDotVector(source_.i(t), temp_lines[GI]);
296  else
297  gate_weights_[GI].MatrixDotVector(curr_input, temp_lines[GI]);
298  FuncInplace<FFunc>(ns_, temp_lines[GI]);
299 
300  SECTION_IF_OPENMP
301  // 1-D forget gates.
302  if (source_.int_mode())
303  gate_weights_[GF1].MatrixDotVector(source_.i(t), temp_lines[GF1]);
304  else
305  gate_weights_[GF1].MatrixDotVector(curr_input, temp_lines[GF1]);
306  FuncInplace<FFunc>(ns_, temp_lines[GF1]);
307 
308  // 2-D forget gates.
309  if (Is2D()) {
310  if (source_.int_mode())
311  gate_weights_[GFS].MatrixDotVector(source_.i(t), temp_lines[GFS]);
312  else
313  gate_weights_[GFS].MatrixDotVector(curr_input, temp_lines[GFS]);
314  FuncInplace<FFunc>(ns_, temp_lines[GFS]);
315  }
316 
317  SECTION_IF_OPENMP
318  // Output gates.
319  if (source_.int_mode())
320  gate_weights_[GO].MatrixDotVector(source_.i(t), temp_lines[GO]);
321  else
322  gate_weights_[GO].MatrixDotVector(curr_input, temp_lines[GO]);
323  FuncInplace<FFunc>(ns_, temp_lines[GO]);
324  END_PARALLEL_IF_OPENMP
325 
326  // Apply forget gate to state.
327  MultiplyVectorsInPlace(ns_, temp_lines[GF1], curr_state);
328  if (Is2D()) {
329  // Max-pool the forget gates (in 2-d) instead of blindly adding.
330  inT8* which_fg_col = which_fg_[t];
331  memset(which_fg_col, 1, ns_ * sizeof(which_fg_col[0]));
332  if (valid_2d) {
333  const double* stepped_state = states[mod_t];
334  for (int i = 0; i < ns_; ++i) {
335  if (temp_lines[GF1][i] < temp_lines[GFS][i]) {
336  curr_state[i] = temp_lines[GFS][i] * stepped_state[i];
337  which_fg_col[i] = 2;
338  }
339  }
340  }
341  }
342  MultiplyAccumulate(ns_, temp_lines[CI], temp_lines[GI], curr_state);
343  // Clip curr_state to a sane range.
344  ClipVector<double>(ns_, -kStateClip, kStateClip, curr_state);
345  if (IsTraining()) {
346  // Save the gate node values.
347  node_values_[CI].WriteTimeStep(t, temp_lines[CI]);
348  node_values_[GI].WriteTimeStep(t, temp_lines[GI]);
349  node_values_[GF1].WriteTimeStep(t, temp_lines[GF1]);
350  node_values_[GO].WriteTimeStep(t, temp_lines[GO]);
351  if (Is2D()) node_values_[GFS].WriteTimeStep(t, temp_lines[GFS]);
352  }
353  FuncMultiply<HFunc>(curr_state, temp_lines[GO], ns_, curr_output);
354  if (IsTraining()) state_.WriteTimeStep(t, curr_state);
355  if (softmax_ != NULL) {
356  if (input.int_mode()) {
357  int_output->WriteTimeStep(0, curr_output);
358  softmax_->ForwardTimeStep(NULL, int_output->i(0), t, softmax_output);
359  } else {
360  softmax_->ForwardTimeStep(curr_output, NULL, t, softmax_output);
361  }
362  output->WriteTimeStep(t, softmax_output);
363  if (type_ == NT_LSTM_SOFTMAX_ENCODED) {
364  CodeInBinary(no_, nf_, softmax_output);
365  }
366  } else if (type_ == NT_LSTM_SUMMARY) {
367  // Output only at the end of a row.
368  if (src_index.IsLast(FD_WIDTH)) {
369  output->WriteTimeStep(dest_index.t(), curr_output);
370  dest_index.Increment();
371  }
372  } else {
373  output->WriteTimeStep(t, curr_output);
374  }
375  // Save states for use by the 2nd dimension only if needed.
376  if (Is2D()) {
377  CopyVector(ns_, curr_state, states[mod_t]);
378  CopyVector(ns_, curr_output, outputs[mod_t]);
379  }
380  // Always zero the states at the end of every row, but only for the major
381  // direction. The 2-D state remains intact.
382  if (src_index.IsLast(FD_WIDTH)) {
383  ZeroVector<double>(ns_, curr_state);
384  ZeroVector<double>(ns_, curr_output);
385  }
386  } while (src_index.Increment());
387 #if DEBUG_DETAIL > 0
388  tprintf("Source:%s\n", name_.string());
389  source_.Print(10);
390  tprintf("State:%s\n", name_.string());
391  state_.Print(10);
392  tprintf("Output:%s\n", name_.string());
393  output->Print(10);
394 #endif
395  if (debug) DisplayForward(*output);
396 }
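Stripped of the 2-D handling, integer mode and scratch management, each cell follows the standard LSTM recurrence. A self-contained single-cell sketch, assuming (per functions.h) that FFunc is the logistic sigmoid and GFunc/HFunc are tanh; the GateActs/LstmCellStep names are illustrative, not tesseract API:

#include <algorithm>
#include <cmath>

// Gate pre-activations for one cell, i.e. the MatrixDotVector results above.
struct GateActs { double ci, gi, gf1, go; };

static double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

// Returns curr_output for one timestep; *state is carried between steps.
static double LstmCellStep(const GateActs& a, double state_clip,
                           double* state) {
  double g = std::tanh(a.ci);   // cell input      (GFunc)
  double i = Sigmoid(a.gi);     // input gate      (FFunc)
  double f = Sigmoid(a.gf1);    // 1-D forget gate (FFunc)
  double o = Sigmoid(a.go);     // output gate     (FFunc)
  *state = f * *state + i * g;  // MultiplyVectorsInPlace + MultiplyAccumulate
  *state = std::max(-state_clip, std::min(state_clip, *state));  // ClipVector
  return o * std::tanh(*state); // FuncMultiply<HFunc> gives curr_output
}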

◆ InitWeights()

int tesseract::LSTM::InitWeights ( float  range,
TRand *  randomizer 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 129 of file lstm.cpp.

129 int LSTM::InitWeights(float range, TRand* randomizer) {
130  Network::SetRandomizer(randomizer);
131  num_weights_ = 0;
132  for (int w = 0; w < WT_COUNT; ++w) {
133  if (w == GFS && !Is2D()) continue;
134  num_weights_ += gate_weights_[w].InitWeightsFloat(
135  ns_, na_ + 1, TestFlag(NF_ADA_GRAD), range, randomizer);
136  }
137  if (softmax_ != NULL) {
138  num_weights_ += softmax_->InitWeights(range, randomizer);
139  }
140  return num_weights_;
141 }
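Each gate matrix is allocated with ns_ rows and na_ + 1 inputs; the extra column is the bias input, which is why PrintW() and PrintDW() read the bias coefficients at column index na_.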

◆ Is2D()

bool tesseract::LSTM::Is2D ( ) const
inline

Definition at line 117 of file lstm.h.

117  bool Is2D() const {
118  return is_2d_;
119  }

◆ OutputShape()

StaticShape tesseract::LSTM::OutputShape ( const StaticShape &  input_shape) const
virtual

Reimplemented from tesseract::Network.

Definition at line 98 of file lstm.cpp.

98 StaticShape LSTM::OutputShape(const StaticShape& input_shape) const {
99  StaticShape result = input_shape;
100  result.set_depth(no_);
101  if (type_ == NT_LSTM_SUMMARY) result.set_width(1);
102  if (softmax_ != NULL) return softmax_->OutputShape(result);
103  return result;
104 }

◆ PrintDW()

void tesseract::LSTM::PrintDW ( )

Definition at line 691 of file lstm.cpp.

691 void LSTM::PrintDW() {
692  tprintf("Delta state:%s\n", name_.string());
693  for (int w = 0; w < WT_COUNT; ++w) {
694  if (w == GFS && !Is2D()) continue;
695  tprintf("Gate %d, inputs\n", w);
696  for (int i = 0; i < ni_; ++i) {
697  tprintf("Row %d:", i);
698  for (int s = 0; s < ns_; ++s)
699  tprintf(" %g", gate_weights_[w].GetDW(s, i));
700  tprintf("\n");
701  }
702  tprintf("Gate %d, outputs\n", w);
703  for (int i = ni_; i < ni_ + ns_; ++i) {
704  tprintf("Row %d:", i - ni_);
705  for (int s = 0; s < ns_; ++s)
706  tprintf(" %g", gate_weights_[w].GetDW(s, i));
707  tprintf("\n");
708  }
709  tprintf("Gate %d, bias\n", w);
710  for (int s = 0; s < ns_; ++s)
711  tprintf(" %g", gate_weights_[w].GetDW(s, na_));
712  tprintf("\n");
713  }
714 }

◆ PrintW()

void tesseract::LSTM::PrintW ( )

Definition at line 665 of file lstm.cpp.

665 void LSTM::PrintW() {
666  tprintf("Weight state:%s\n", name_.string());
667  for (int w = 0; w < WT_COUNT; ++w) {
668  if (w == GFS && !Is2D()) continue;
669  tprintf("Gate %d, inputs\n", w);
670  for (int i = 0; i < ni_; ++i) {
671  tprintf("Row %d:", i);
672  for (int s = 0; s < ns_; ++s)
673  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
674  tprintf("\n");
675  }
676  tprintf("Gate %d, outputs\n", w);
677  for (int i = ni_; i < ni_ + ns_; ++i) {
678  tprintf("Row %d:", i - ni_);
679  for (int s = 0; s < ns_; ++s)
680  tprintf(" %g", gate_weights_[w].GetWeights(s)[i]);
681  tprintf("\n");
682  }
683  tprintf("Gate %d, bias\n", w);
684  for (int s = 0; s < ns_; ++s)
685  tprintf(" %g", gate_weights_[w].GetWeights(s)[na_]);
686  tprintf("\n");
687  }
688 }

◆ Serialize()

bool tesseract::LSTM::Serialize ( TFile *  fp) const
virtual

Reimplemented from tesseract::Network.

Definition at line 168 of file lstm.cpp.

168 bool LSTM::Serialize(TFile* fp) const {
169  if (!Network::Serialize(fp)) return false;
170  if (fp->FWrite(&na_, sizeof(na_), 1) != 1) return false;
171  for (int w = 0; w < WT_COUNT; ++w) {
172  if (w == GFS && !Is2D()) continue;
173  if (!gate_weights_[w].Serialize(IsTraining(), fp)) return false;
174  }
175  if (softmax_ != NULL && !softmax_->Serialize(fp)) return false;
176  return true;
177 }

◆ SetEnableTraining()

void tesseract::LSTM::SetEnableTraining ( TrainingState  state)
virtual

Reimplemented from tesseract::Network.

Definition at line 108 of file lstm.cpp.

108 void LSTM::SetEnableTraining(TrainingState state) {
109  if (state == TS_RE_ENABLE) {
110  // Enable only from temp disabled.
111  if (training_ == TS_TEMP_DISABLE) training_ = TS_ENABLED;
112  } else if (state == TS_TEMP_DISABLE) {
113  // Temp disable only from enabled.
114  if (training_ == TS_ENABLED) training_ = state;
115  } else {
116  if (state == TS_ENABLED && training_ != TS_ENABLED) {
117  for (int w = 0; w < WT_COUNT; ++w) {
118  if (w == GFS && !Is2D()) continue;
119  gate_weights_[w].InitBackward();
120  }
121  }
122  training_ = state;
123  }
124  if (softmax_ != NULL) softmax_->SetEnableTraining(state);
125 }

◆ spec()

virtual STRING tesseract::LSTM::spec ( ) const
inlinevirtual

Reimplemented from tesseract::Network.

Definition at line 58 of file lstm.h.

58  virtual STRING spec() const {
59  STRING spec;
60  if (type_ == NT_LSTM)
61  spec.add_str_int("Lfx", ns_);
62  else if (type_ == NT_LSTM_SUMMARY)
63  spec.add_str_int("Lfxs", ns_);
64  else if (type_ == NT_LSTM_SOFTMAX)
65  spec.add_str_int("LS", ns_);
66  else if (type_ == NT_LSTM_SOFTMAX_ENCODED)
67  spec.add_str_int("LE", ns_);
68  if (softmax_ != NULL) spec += softmax_->spec();
69  return spec;
70  }
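For example, a 1-D NT_LSTM with ns_ == 96 yields the spec string "Lfx96", while an NT_LSTM_SOFTMAX_ENCODED of the same size yields "LE96" followed by the spec of the attached softmax.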

◆ Update()

void tesseract::LSTM::Update ( float  learning_rate,
float  momentum,
int  num_samples 
)
virtual

Reimplemented from tesseract::Network.

Definition at line 632 of file lstm.cpp.

632 void LSTM::Update(float learning_rate, float momentum, int num_samples) {
633 #if DEBUG_DETAIL > 3
634  PrintW();
635 #endif
636  for (int w = 0; w < WT_COUNT; ++w) {
637  if (w == GFS && !Is2D()) continue;
638  gate_weights_[w].Update(learning_rate, momentum, num_samples);
639  }
640  if (softmax_ != NULL) {
641  softmax_->Update(learning_rate, momentum, num_samples);
642  }
643 #if DEBUG_DETAIL > 3
644  PrintDW();
645 #endif
646 }
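Update() forwards the learning rate and momentum to WeightMatrix::Update for every active gate, and then to the attached softmax. As a generic illustration of a momentum step of this kind (a sketch under assumed semantics only, not WeightMatrix's actual arithmetic; deltas are assumed already sign-folded):

// dw: delta accumulated since the last update; v: momentum buffer; w: weights.
static void MomentumStep(int n, float learning_rate, float momentum,
                         double* dw, double* v, double* w) {
  for (int i = 0; i < n; ++i) {
    v[i] = momentum * v[i] + learning_rate * dw[i];
    w[i] += v[i];
    dw[i] = 0.0;  // The accumulated delta is consumed by the update.
  }
}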

The documentation for this class was generated from the following files:

lstm.h
lstm.cpp