34 #define PARALLEL_IF_OPENMP(__num_threads) \ 35 PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \ 36 PRAGMA(omp sections nowait) { \ 38 #define SECTION_IF_OPENMP \ 43 #define END_PARALLEL_IF_OPENMP \ 49 #ifdef _MSC_VER // Different _Pragma 50 #define PRAGMA(x) __pragma(x) 52 #define PRAGMA(x) _Pragma(#x) 56 #define PARALLEL_IF_OPENMP(__num_threads) 57 #define SECTION_IF_OPENMP 58 #define END_PARALLEL_IF_OPENMP 76 is_2d_(two_dimensional),
79 if (two_dimensional) na_ += ns_;
88 tprintf(
"%d is invalid type of LSTM!\n", type);
102 if (softmax_ != NULL)
return softmax_->
OutputShape(result);
117 for (
int w = 0; w <
WT_COUNT; ++w) {
118 if (w ==
GFS && !
Is2D())
continue;
132 for (
int w = 0; w <
WT_COUNT; ++w) {
133 if (w ==
GFS && !
Is2D())
continue;
137 if (softmax_ != NULL) {
145 for (
int w = 0; w <
WT_COUNT; ++w) {
146 if (w ==
GFS && !
Is2D())
continue;
149 if (softmax_ != NULL) {
156 for (
int w = 0; w <
WT_COUNT; ++w) {
157 if (w ==
GFS && !
Is2D())
continue;
162 if (softmax_ != NULL) {
170 if (fp->
FWrite(&na_,
sizeof(na_), 1) != 1)
return false;
171 for (
int w = 0; w <
WT_COUNT; ++w) {
172 if (w ==
GFS && !
Is2D())
continue;
175 if (softmax_ != NULL && !softmax_->
Serialize(fp))
return false;
182 if (fp->
FReadEndian(&na_,
sizeof(na_), 1) != 1)
return false;
191 for (
int w = 0; w <
WT_COUNT; ++w) {
192 if (w ==
GFS && !
Is2D())
continue;
196 is_2d_ = na_ - nf_ ==
ni_ + 2 * ns_;
202 if (softmax_ ==
nullptr)
return false;
215 input_width_ = input.
Width();
216 if (softmax_ != NULL)
222 ResizeForward(input);
225 for (
int i = 0; i <
WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
228 curr_state.
Init(ns_, scratch);
229 ZeroVector<double>(ns_, curr_state);
230 curr_output.
Init(ns_, scratch);
231 ZeroVector<double>(ns_, curr_output);
240 for (
int i = 0; i < buf_width; ++i) {
241 states[i].Init(ns_, scratch);
242 ZeroVector<double>(ns_, states[i]);
243 outputs[i].Init(ns_, scratch);
244 ZeroVector<double>(ns_, outputs[i]);
250 if (softmax_ != NULL) {
251 softmax_output.Init(
no_, scratch);
252 ZeroVector<double>(
no_, softmax_output);
257 curr_input.
Init(na_, scratch);
262 int t = src_index.
t();
264 bool valid_2d =
Is2D();
270 int mod_t =
Modulo(t, buf_width);
273 if (softmax_ != NULL) {
287 gate_weights_[
CI].MatrixDotVector(source_.
i(t), temp_lines[
CI]);
290 FuncInplace<GFunc>(ns_, temp_lines[
CI]);
295 gate_weights_[
GI].MatrixDotVector(source_.
i(t), temp_lines[
GI]);
298 FuncInplace<FFunc>(ns_, temp_lines[
GI]);
303 gate_weights_[
GF1].MatrixDotVector(source_.
i(t), temp_lines[
GF1]);
306 FuncInplace<FFunc>(ns_, temp_lines[
GF1]);
311 gate_weights_[
GFS].MatrixDotVector(source_.
i(t), temp_lines[
GFS]);
314 FuncInplace<FFunc>(ns_, temp_lines[
GFS]);
320 gate_weights_[
GO].MatrixDotVector(source_.
i(t), temp_lines[
GO]);
323 FuncInplace<FFunc>(ns_, temp_lines[
GO]);
330 inT8* which_fg_col = which_fg_[t];
331 memset(which_fg_col, 1, ns_ *
sizeof(which_fg_col[0]));
333 const double* stepped_state = states[mod_t];
334 for (
int i = 0; i < ns_; ++i) {
335 if (temp_lines[GF1][i] < temp_lines[
GFS][i]) {
336 curr_state[i] = temp_lines[
GFS][i] * stepped_state[i];
353 FuncMultiply<HFunc>(curr_state, temp_lines[
GO], ns_, curr_output);
355 if (softmax_ != NULL) {
370 dest_index.Increment();
383 ZeroVector<double>(ns_, curr_state);
384 ZeroVector<double>(ns_, curr_output);
408 outputerr.
Init(ns_, scratch);
411 curr_stateerr.
Init(ns_, scratch);
412 curr_sourceerr.
Init(na_, scratch);
413 ZeroVector<double>(ns_, curr_stateerr);
414 ZeroVector<double>(na_, curr_sourceerr);
417 for (
int g = 0; g <
WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
425 for (
int t = 0; t < buf_width; ++t) {
426 stateerr[t].Init(ns_, scratch);
427 sourceerr[t].Init(na_, scratch);
428 ZeroVector<double>(ns_, stateerr[t]);
429 ZeroVector<double>(na_, sourceerr[t]);
435 sourceerr_temps[w].Init(na_, scratch);
436 int width = input_width_;
439 for (
int w = 0; w <
WT_COUNT; ++w) {
440 gate_errors_t[w].
Init(ns_, width, scratch);
445 if (softmax_ != NULL) {
446 softmax_errors.
Init(
no_, scratch);
447 softmax_errors_t.
Init(
no_, width, scratch);
449 double state_clip =
Is2D() ? 9.0 : 4.0;
452 fwd_deltas.
Print(10);
460 int t = dest_index.
t();
477 int mod_t =
Modulo(t, buf_width);
480 ZeroVector<double>(na_, curr_sourceerr);
481 ZeroVector<double>(ns_, curr_stateerr);
487 src_index.Decrement();
489 ZeroVector<double>(ns_, outputerr);
491 }
else if (softmax_ == NULL) {
495 softmax_errors_t.
get(), outputerr);
503 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
504 for (
int i = 0; i < ns_; ++i) {
505 curr_stateerr[i] *= next_node_gf1[i];
508 if (
Is2D() && t + 1 < width) {
509 for (
int i = 0; i < ns_; ++i) {
510 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
513 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
514 const double* right_stateerr = stateerr[mod_t];
515 for (
int i = 0; i < ns_; ++i) {
516 if (which_fg_[down_pos][i] == 2) {
517 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
525 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
527 if (t + 10 > width) {
529 for (
int i = 0; i < ns_; ++i)
530 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
531 curr_sourceerr[
ni_ + nf_ + i]);
539 node_values_[
CI].FuncMultiply3<
GPrime>(t, node_values_[
GI], t,
540 curr_stateerr, gate_errors[
CI]);
541 ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
CI].
get());
547 node_values_[
GI].FuncMultiply3<
FPrime>(t, node_values_[
CI], t,
548 curr_stateerr, gate_errors[
GI]);
549 ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GI].
get());
556 node_values_[
GF1].FuncMultiply3<
FPrime>(t, state_, t - 1, curr_stateerr,
558 ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GF1].
get());
560 sourceerr_temps[GF1]);
562 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[GF1][0]));
563 memset(sourceerr_temps[GF1], 0, na_ *
sizeof(*sourceerr_temps[GF1]));
569 node_values_[
GFS].FuncMultiply3<
FPrime>(t, state_, up_pos, curr_stateerr,
571 ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GFS].
get());
573 sourceerr_temps[GFS]);
575 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[GFS][0]));
576 memset(sourceerr_temps[GFS], 0, na_ *
sizeof(*sourceerr_temps[GFS]));
584 ClipVector(ns_, -kErrClip, kErrClip, gate_errors[
GO].
get());
589 SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
590 sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
595 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
596 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
600 for (
int w = 0; w <
WT_COUNT; ++w) {
607 source_t.
Init(na_, width, scratch);
609 state_t.
Init(ns_, width, scratch);
612 #pragma omp parallel for num_threads(GFS) if (!Is2D()) 614 for (
int w = 0; w <
WT_COUNT; ++w) {
615 if (w ==
GFS && !
Is2D())
continue;
618 if (softmax_ != NULL) {
632 void LSTM::Update(
float learning_rate,
float momentum,
int num_samples) {
636 for (
int w = 0; w <
WT_COUNT; ++w) {
637 if (w ==
GFS && !
Is2D())
continue;
638 gate_weights_[w].
Update(learning_rate, momentum, num_samples);
640 if (softmax_ != NULL) {
641 softmax_->
Update(learning_rate, momentum, num_samples);
652 double* changed)
const {
654 const LSTM* lstm =
static_cast<const LSTM*
>(&other);
655 for (
int w = 0; w <
WT_COUNT; ++w) {
656 if (w ==
GFS && !
Is2D())
continue;
659 if (softmax_ != NULL) {
667 for (
int w = 0; w <
WT_COUNT; ++w) {
668 if (w ==
GFS && !
Is2D())
continue;
669 tprintf(
"Gate %d, inputs\n", w);
670 for (
int i = 0; i <
ni_; ++i) {
672 for (
int s = 0; s < ns_; ++s)
673 tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
676 tprintf(
"Gate %d, outputs\n", w);
677 for (
int i = ni_; i < ni_ + ns_; ++i) {
679 for (
int s = 0; s < ns_; ++s)
680 tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
684 for (
int s = 0; s < ns_; ++s)
685 tprintf(
" %g", gate_weights_[w].GetWeights(s)[na_]);
693 for (
int w = 0; w <
WT_COUNT; ++w) {
694 if (w ==
GFS && !
Is2D())
continue;
695 tprintf(
"Gate %d, inputs\n", w);
696 for (
int i = 0; i <
ni_; ++i) {
698 for (
int s = 0; s < ns_; ++s)
699 tprintf(
" %g", gate_weights_[w].GetDW(s, i));
702 tprintf(
"Gate %d, outputs\n", w);
703 for (
int i = ni_; i < ni_ + ns_; ++i) {
705 for (
int s = 0; s < ns_; ++s)
706 tprintf(
" %g", gate_weights_[w].GetDW(s, i));
710 for (
int s = 0; s < ns_; ++s)
711 tprintf(
" %g", gate_weights_[w].GetDW(s, na_));
717 void LSTM::ResizeForward(
const NetworkIO& input) {
718 source_.
Resize(input, na_);
722 for (
int w = 0; w <
WT_COUNT; ++w) {
723 if (w ==
GFS && !
Is2D())
continue;
virtual void CountAlternators(const Network &other, double *same, double *changed) const
void add_str_int(const char *str, int number)
bool AddOffset(int offset, FlexDimensions dimension)
void ReadTimeStep(int t, double *output) const
void Debug2D(const char *msg)
void AccumulateVector(int n, const double *src, double *dest)
virtual void Update(float learning_rate, float momentum, int num_samples)
int index(FlexDimensions dimension) const
virtual void SetRandomizer(TRand *randomizer)
void init_to_size(int size, T t)
void MatrixDotVector(const double *u, double *v) const
void DisplayForward(const NetworkIO &matrix)
void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
void CopyVector(int n, const double *src, double *dest)
static Network * CreateFromFile(TFile *fp)
const char * string() const
void Resize(const NetworkIO &src, int num_features)
void VectorDotMatrix(const double *u, double *v) const
void WriteStrided(int t, const float *data)
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
int FReadEndian(void *buffer, int size, int count)
#define SECTION_IF_OPENMP
int IntCastRounded(double x)
virtual void ConvertToInt()
virtual void DebugWeights()
virtual bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas)
int Size(FlexDimensions dimension) const
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
virtual void Update(float learning_rate, float momentum, int num_samples)
void CodeInBinary(int n, int nf, double *vec)
virtual bool DeSerialize(TFile *fp)
#define PARALLEL_IF_OPENMP(__num_threads)
void Init(int size, NetworkScratch *scratch)
TransposedArray * get() const
bool TestFlag(NetworkFlags flag) const
bool IsLast(FlexDimensions dimension) const
LSTM(const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
virtual void SetEnableTraining(TrainingState state)
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features, const NetworkIO &src, int src_t, int src_offset)
int FWrite(const void *buffer, int size, int count)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void CopyWithNormalization(const NetworkIO &src, const NetworkIO &scale)
void Update(double learning_rate, double momentum, int num_samples)
void ResizeNoInit(int size1, int size2)
void ResizeFloat(const NetworkIO &src, int num_features)
virtual int InitWeights(float range, TRand *randomizer)
void MultiplyVectorsInPlace(int n, const double *src, double *inout)
int InitWeightsFloat(int no, int ni, bool ada_grad, float weight_range, TRand *randomizer)
void ForwardTimeStep(const double *d_input, const inT8 *i_input, int t, double *output_line)
void ResizeXTo1(const NetworkIO &src, int num_features)
void Print(int num) const
virtual void SetEnableTraining(TrainingState state)
const inT8 * i(int t) const
const StrideMap & stride_map() const
virtual void DebugWeights()
virtual void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output)
void Init(int size1, int size2, NetworkScratch *scratch)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
void WriteTimeStep(int t, const double *input)
void Transpose(TransposedArray *dest) const
void Func2Multiply3(const NetworkIO &v_io, int t, const double *w, double *product) const
virtual bool Serialize(TFile *fp) const
void ClipVector(int n, T lower, T upper, T *vec)
virtual bool Serialize(TFile *fp) const
#define END_PARALLEL_IF_OPENMP
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
void WriteTimeStepPart(int t, int offset, int num_features, const double *input)
virtual void ConvertToInt()
void DisplayBackward(const NetworkIO &matrix)
virtual bool Serialize(TFile *fp) const
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
virtual void CountAlternators(const Network &other, double *same, double *changed) const
void set_width(int value)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
void FinishBackward(const TransposedArray &errors_t)
virtual int InitWeights(float range, TRand *randomizer)
void PrintUnTransposed(int num)
virtual StaticShape OutputShape(const StaticShape &input_shape) const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void set_depth(int value)