34 #define PARALLEL_IF_OPENMP(__num_threads)                                  \    35   PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \    36     PRAGMA(omp sections nowait) {                                          \    38 #define SECTION_IF_OPENMP \    43 #define END_PARALLEL_IF_OPENMP \    49 #ifdef _MSC_VER  // Different _Pragma    50 #define PRAGMA(x) __pragma(x)    52 #define PRAGMA(x) _Pragma(#x)    56 #define PARALLEL_IF_OPENMP(__num_threads)    57 #define SECTION_IF_OPENMP    58 #define END_PARALLEL_IF_OPENMP    76       is_2d_(two_dimensional),
    79   if (two_dimensional) na_ += ns_;
    88     tprintf(
"%d is invalid type of LSTM!\n", type);
   102   if (softmax_ != NULL) 
return softmax_->
OutputShape(result);
   117       for (
int w = 0; w < 
WT_COUNT; ++w) {
   118         if (w == 
GFS && !
Is2D()) 
continue;
   132   for (
int w = 0; w < 
WT_COUNT; ++w) {
   133     if (w == 
GFS && !
Is2D()) 
continue;
   137   if (softmax_ != NULL) {
   145   for (
int w = 0; w < 
WT_COUNT; ++w) {
   146     if (w == 
GFS && !
Is2D()) 
continue;
   149   if (softmax_ != NULL) {
   156   for (
int w = 0; w < 
WT_COUNT; ++w) {
   157     if (w == 
GFS && !
Is2D()) 
continue;
   162   if (softmax_ != NULL) {
   170   if (fp->
FWrite(&na_, 
sizeof(na_), 1) != 1) 
return false;
   171   for (
int w = 0; w < 
WT_COUNT; ++w) {
   172     if (w == 
GFS && !
Is2D()) 
continue;
   175   if (softmax_ != NULL && !softmax_->
Serialize(fp)) 
return false;
   182   if (fp->
FReadEndian(&na_, 
sizeof(na_), 1) != 1) 
return false;
   191   for (
int w = 0; w < 
WT_COUNT; ++w) {
   192     if (w == 
GFS && !
Is2D()) 
continue;
   196       is_2d_ = na_ - nf_ == 
ni_ + 2 * ns_;
   202     if (softmax_ == 
nullptr) 
return false;
   215   input_width_ = input.
Width();
   216   if (softmax_ != NULL)
   222   ResizeForward(input);
   225   for (
int i = 0; i < 
WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
   228   curr_state.
Init(ns_, scratch);
   229   ZeroVector<double>(ns_, curr_state);
   230   curr_output.
Init(ns_, scratch);
   231   ZeroVector<double>(ns_, curr_output);
   240     for (
int i = 0; i < buf_width; ++i) {
   241       states[i].Init(ns_, scratch);
   242       ZeroVector<double>(ns_, states[i]);
   243       outputs[i].Init(ns_, scratch);
   244       ZeroVector<double>(ns_, outputs[i]);
   250   if (softmax_ != NULL) {
   251     softmax_output.Init(
no_, scratch);
   252     ZeroVector<double>(
no_, softmax_output);
   257   curr_input.
Init(na_, scratch);
   262     int t = src_index.
t();
   264     bool valid_2d = 
Is2D();
   270     int mod_t = 
Modulo(t, buf_width);      
   273     if (softmax_ != NULL) {
   287       gate_weights_[
CI].MatrixDotVector(source_.
i(t), temp_lines[
CI]);
   290     FuncInplace<GFunc>(ns_, temp_lines[
CI]);
   295       gate_weights_[
GI].MatrixDotVector(source_.
i(t), temp_lines[
GI]);
   298     FuncInplace<FFunc>(ns_, temp_lines[
GI]);
   303       gate_weights_[
GF1].MatrixDotVector(source_.
i(t), temp_lines[
GF1]);
   306     FuncInplace<FFunc>(ns_, temp_lines[
GF1]);
   311         gate_weights_[
GFS].MatrixDotVector(source_.
i(t), temp_lines[
GFS]);
   314       FuncInplace<FFunc>(ns_, temp_lines[
GFS]);
   320       gate_weights_[
GO].MatrixDotVector(source_.
i(t), temp_lines[
GO]);
   323     FuncInplace<FFunc>(ns_, temp_lines[
GO]);
   330       inT8* which_fg_col = which_fg_[t];
   331       memset(which_fg_col, 1, ns_ * 
sizeof(which_fg_col[0]));
   333         const double* stepped_state = states[mod_t];
   334         for (
int i = 0; i < ns_; ++i) {
   335           if (temp_lines[GF1][i] < temp_lines[
GFS][i]) {
   336             curr_state[i] = temp_lines[
GFS][i] * stepped_state[i];
   353     FuncMultiply<HFunc>(curr_state, temp_lines[
GO], ns_, curr_output);
   355     if (softmax_ != NULL) {
   370         dest_index.Increment();
   383       ZeroVector<double>(ns_, curr_state);
   384       ZeroVector<double>(ns_, curr_output);
   408   outputerr.
Init(ns_, scratch);
   411   curr_stateerr.
Init(ns_, scratch);
   412   curr_sourceerr.
Init(na_, scratch);
   413   ZeroVector<double>(ns_, curr_stateerr);
   414   ZeroVector<double>(na_, curr_sourceerr);
   417   for (
int g = 0; g < 
WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
   425     for (
int t = 0; t < buf_width; ++t) {
   426       stateerr[t].Init(ns_, scratch);
   427       sourceerr[t].Init(na_, scratch);
   428       ZeroVector<double>(ns_, stateerr[t]);
   429       ZeroVector<double>(na_, sourceerr[t]);
   435     sourceerr_temps[w].Init(na_, scratch);
   436   int width = input_width_;
   439   for (
int w = 0; w < 
WT_COUNT; ++w) {
   440     gate_errors_t[w].
Init(ns_, width, scratch);
   445   if (softmax_ != NULL) {
   446     softmax_errors.
Init(
no_, scratch);
   447     softmax_errors_t.
Init(
no_, width, scratch);
   449   double state_clip = 
Is2D() ? 9.0 : 4.0;
   452   fwd_deltas.
Print(10);
   460     int t = dest_index.
t();
   477     int mod_t = 
Modulo(t, buf_width);      
   480       ZeroVector<double>(na_, curr_sourceerr);
   481       ZeroVector<double>(ns_, curr_stateerr);
   487         src_index.Decrement();
   489         ZeroVector<double>(ns_, outputerr);
   491     } 
else if (softmax_ == NULL) {
   495                                  softmax_errors_t.
get(), outputerr);
   503       const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
   504       for (
int i = 0; i < ns_; ++i) {
   505         curr_stateerr[i] *= next_node_gf1[i];
   508     if (
Is2D() && t + 1 < width) {
   509       for (
int i = 0; i < ns_; ++i) {
   510         if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
   513         const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
   514         const double* right_stateerr = stateerr[mod_t];
   515         for (
int i = 0; i < ns_; ++i) {
   516           if (which_fg_[down_pos][i] == 2) {
   517             curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
   525     ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
   527     if (t + 10 > width) {
   529       for (
int i = 0; i < ns_; ++i)
   530         tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
   531                 curr_sourceerr[
ni_ + nf_ + i]);
   539     node_values_[
CI].FuncMultiply3<
GPrime>(t, node_values_[
GI], t,
   540                                            curr_stateerr, gate_errors[
CI]);
   541     ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
CI].
get());
   547     node_values_[
GI].FuncMultiply3<
FPrime>(t, node_values_[
CI], t,
   548                                            curr_stateerr, gate_errors[
GI]);
   549     ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GI].
get());
   556       node_values_[
GF1].FuncMultiply3<
FPrime>(t, state_, t - 1, curr_stateerr,
   558       ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GF1].
get());
   560                                          sourceerr_temps[GF1]);
   562       memset(gate_errors[
GF1], 0, ns_ * 
sizeof(gate_errors[GF1][0]));
   563       memset(sourceerr_temps[GF1], 0, na_ * 
sizeof(*sourceerr_temps[GF1]));
   569       node_values_[
GFS].FuncMultiply3<
FPrime>(t, state_, up_pos, curr_stateerr,
   571       ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GFS].
get());
   573                                          sourceerr_temps[GFS]);
   575       memset(gate_errors[
GFS], 0, ns_ * 
sizeof(gate_errors[GFS][0]));
   576       memset(sourceerr_temps[GFS], 0, na_ * 
sizeof(*sourceerr_temps[GFS]));
   584     ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GO].
get());
   589     SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
   590                sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
   595       CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
   596       CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
   600   for (
int w = 0; w < 
WT_COUNT; ++w) {
   607   source_t.
Init(na_, width, scratch);
   609   state_t.
Init(ns_, width, scratch);
   612 #pragma omp parallel for num_threads(GFS) if (!Is2D())   614   for (
int w = 0; w < 
WT_COUNT; ++w) {
   615     if (w == 
GFS && !
Is2D()) 
continue;
   618   if (softmax_ != NULL) {
   632 void LSTM::Update(
float learning_rate, 
float momentum, 
int num_samples) {
   636   for (
int w = 0; w < 
WT_COUNT; ++w) {
   637     if (w == 
GFS && !
Is2D()) 
continue;
   638     gate_weights_[w].
Update(learning_rate, momentum, num_samples);
   640   if (softmax_ != NULL) {
   641     softmax_->
Update(learning_rate, momentum, num_samples);
   652                             double* changed)
 const {
   654   const LSTM* lstm = 
static_cast<const LSTM*
>(&other);
   655   for (
int w = 0; w < 
WT_COUNT; ++w) {
   656     if (w == 
GFS && !
Is2D()) 
continue;
   659   if (softmax_ != NULL) {
   667   for (
int w = 0; w < 
WT_COUNT; ++w) {
   668     if (w == 
GFS && !
Is2D()) 
continue;
   669     tprintf(
"Gate %d, inputs\n", w);
   670     for (
int i = 0; i < 
ni_; ++i) {
   672       for (
int s = 0; s < ns_; ++s)
   673         tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
   676     tprintf(
"Gate %d, outputs\n", w);
   677     for (
int i = ni_; i < ni_ + ns_; ++i) {
   679       for (
int s = 0; s < ns_; ++s)
   680         tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
   684     for (
int s = 0; s < ns_; ++s)
   685       tprintf(
" %g", gate_weights_[w].GetWeights(s)[na_]);
   693   for (
int w = 0; w < 
WT_COUNT; ++w) {
   694     if (w == 
GFS && !
Is2D()) 
continue;
   695     tprintf(
"Gate %d, inputs\n", w);
   696     for (
int i = 0; i < 
ni_; ++i) {
   698       for (
int s = 0; s < ns_; ++s)
   699         tprintf(
" %g", gate_weights_[w].GetDW(s, i));
   702     tprintf(
"Gate %d, outputs\n", w);
   703     for (
int i = ni_; i < ni_ + ns_; ++i) {
   705       for (
int s = 0; s < ns_; ++s)
   706         tprintf(
" %g", gate_weights_[w].GetDW(s, i));
   710     for (
int s = 0; s < ns_; ++s)
   711       tprintf(
" %g", gate_weights_[w].GetDW(s, na_));
   717 void LSTM::ResizeForward(
const NetworkIO& input) {
   718   source_.
Resize(input, na_);
   722     for (
int w = 0; w < 
WT_COUNT; ++w) {
   723       if (w == 
GFS && !
Is2D()) 
continue;
 virtual void CountAlternators(const Network &other, double *same, double *changed) const
void add_str_int(const char *str, int number)
bool AddOffset(int offset, FlexDimensions dimension)
void ReadTimeStep(int t, double *output) const
void Debug2D(const char *msg)
void AccumulateVector(int n, const double *src, double *dest)
virtual void Update(float learning_rate, float momentum, int num_samples)
int index(FlexDimensions dimension) const
virtual void SetRandomizer(TRand *randomizer)
void init_to_size(int size, T t)
void MatrixDotVector(const double *u, double *v) const
void DisplayForward(const NetworkIO &matrix)
void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
void CopyVector(int n, const double *src, double *dest)
static Network * CreateFromFile(TFile *fp)
const char * string() const
void Resize(const NetworkIO &src, int num_features)
void VectorDotMatrix(const double *u, double *v) const
void WriteStrided(int t, const float *data)
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
int FReadEndian(void *buffer, int size, int count)
#define SECTION_IF_OPENMP
int IntCastRounded(double x)
virtual void ConvertToInt()
virtual void DebugWeights()
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
int Size(FlexDimensions dimension) const
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
virtual void Update(float learning_rate, float momentum, int num_samples)
void CodeInBinary(int n, int nf, double *vec)
virtual bool DeSerialize(TFile *fp)
#define PARALLEL_IF_OPENMP(__num_threads)
void Init(int size, NetworkScratch *scratch)
TransposedArray * get() const
bool TestFlag(NetworkFlags flag) const
bool IsLast(FlexDimensions dimension) const
LSTM(const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
virtual void SetEnableTraining(TrainingState state)
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features, const NetworkIO &src, int src_t, int src_offset)
int FWrite(const void *buffer, int size, int count)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void CopyWithNormalization(const NetworkIO &src, const NetworkIO &scale)
void Update(double learning_rate, double momentum, int num_samples)
void ResizeNoInit(int size1, int size2)
void ResizeFloat(const NetworkIO &src, int num_features)
virtual int InitWeights(float range, TRand *randomizer)
void MultiplyVectorsInPlace(int n, const double *src, double *inout)
int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range, TRand *randomizer)
void ForwardTimeStep(const double *d_input, const inT8 *i_input, int t, double *output_line)
void ResizeXTo1(const NetworkIO &src, int num_features)
void Print(int num) const
virtual void SetEnableTraining(TrainingState state)
const inT8 * i(int t) const
const StrideMap & stride_map() const
virtual void DebugWeights()
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
void Init(int size1, int size2, NetworkScratch *scratch)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
void WriteTimeStep(int t, const double *input)
void Transpose(TransposedArray *dest) const
void Func2Multiply3(const NetworkIO &v_io, int t, const double *w, double *product) const
virtual bool Serialize(TFile *fp) const
void ClipVector(int n, T lower, T upper, T *vec)
virtual bool Serialize(TFile *fp) const
#define END_PARALLEL_IF_OPENMP
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
void WriteTimeStepPart(int t, int offset, int num_features, const double *input)
virtual void ConvertToInt()
void DisplayBackward(const NetworkIO &matrix)
virtual bool Serialize(TFile *fp) const
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
virtual void CountAlternators(const Network &other, double *same, double *changed) const
void set_width(int value)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
void FinishBackward(const TransposedArray &errors_t)
virtual int InitWeights(float range, TRand *randomizer)
void PrintUnTransposed(int num)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void set_depth(int value)