404 back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_,
ni_);
407 NetworkScratch::FloatVec outputerr;
408 outputerr.Init(ns_, scratch);
410 NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
411 curr_stateerr.Init(ns_, scratch);
412 curr_sourceerr.Init(na_, scratch);
413 ZeroVector<double>(ns_, curr_stateerr);
414 ZeroVector<double>(na_, curr_sourceerr);
416 NetworkScratch::FloatVec gate_errors[
WT_COUNT];
417 for (
int g = 0; g <
WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
423 stateerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
424 sourceerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
425 for (
int t = 0; t < buf_width; ++t) {
426 stateerr[t].Init(ns_, scratch);
427 sourceerr[t].Init(na_, scratch);
428 ZeroVector<double>(ns_, stateerr[t]);
429 ZeroVector<double>(na_, sourceerr[t]);
433 NetworkScratch::FloatVec sourceerr_temps[
WT_COUNT];
435 sourceerr_temps[w].Init(na_, scratch);
436 int width = input_width_;
438 NetworkScratch::GradientStore gate_errors_t[
WT_COUNT];
439 for (
int w = 0; w <
WT_COUNT; ++w) {
440 gate_errors_t[w].Init(ns_, width, scratch);
443 NetworkScratch::FloatVec softmax_errors;
444 NetworkScratch::GradientStore softmax_errors_t;
445 if (softmax_ != NULL) {
446 softmax_errors.Init(
no_, scratch);
447 softmax_errors_t.Init(
no_, width, scratch);
449 double state_clip =
Is2D() ? 9.0 : 4.0;
452 fwd_deltas.Print(10);
454 StrideMap::Index dest_index(input_map_);
455 dest_index.InitToLast();
457 StrideMap::Index src_index(fwd_deltas.stride_map());
458 src_index.InitToLast();
460 int t = dest_index.t();
461 bool at_last_x = dest_index.IsLast(
FD_WIDTH);
468 StrideMap::Index up_index(dest_index);
469 if (up_index.AddOffset(-1,
FD_HEIGHT)) up_pos = up_index.t();
472 StrideMap::Index down_index(dest_index);
473 if (down_index.AddOffset(1,
FD_HEIGHT)) down_pos = down_index.t();
477 int mod_t =
Modulo(t, buf_width);
480 ZeroVector<double>(na_, curr_sourceerr);
481 ZeroVector<double>(ns_, curr_stateerr);
486 fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
487 src_index.Decrement();
489 ZeroVector<double>(ns_, outputerr);
491 }
else if (softmax_ == NULL) {
492 fwd_deltas.ReadTimeStep(t, outputerr);
495 softmax_errors_t.get(), outputerr);
503 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
504 for (
int i = 0; i < ns_; ++i) {
505 curr_stateerr[i] *= next_node_gf1[i];
508 if (
Is2D() && t + 1 < width) {
509 for (
int i = 0; i < ns_; ++i) {
510 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
513 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
514 const double* right_stateerr = stateerr[mod_t];
515 for (
int i = 0; i < ns_; ++i) {
516 if (which_fg_[down_pos][i] == 2) {
517 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
525 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
527 if (t + 10 > width) {
529 for (
int i = 0; i < ns_; ++i)
530 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
531 curr_sourceerr[
ni_ + nf_ + i]);
539 node_values_[
CI].FuncMultiply3<GPrime>(t, node_values_[
GI], t,
540 curr_stateerr, gate_errors[
CI]);
543 gate_errors_t[
CI].get()->WriteStrided(t, gate_errors[CI]);
547 node_values_[
GI].FuncMultiply3<FPrime>(t, node_values_[
CI], t,
548 curr_stateerr, gate_errors[
GI]);
551 gate_errors_t[
GI].get()->WriteStrided(t, gate_errors[GI]);
556 node_values_[
GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
560 sourceerr_temps[GF1]);
562 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[GF1][0]));
563 memset(sourceerr_temps[GF1], 0, na_ *
sizeof(*sourceerr_temps[GF1]));
565 gate_errors_t[
GF1].get()->WriteStrided(t, gate_errors[GF1]);
569 node_values_[
GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
573 sourceerr_temps[GFS]);
575 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[GFS][0]));
576 memset(sourceerr_temps[GFS], 0, na_ *
sizeof(*sourceerr_temps[GFS]));
578 if (
Is2D()) gate_errors_t[
GFS].get()->WriteStrided(t, gate_errors[GFS]);
582 state_.Func2Multiply3<HFunc, FPrime>(node_values_[
GO], t, outputerr,
586 gate_errors_t[
GO].get()->WriteStrided(t, gate_errors[GO]);
589 SumVectors(na_, sourceerr_temps[CI], sourceerr_temps[GI],
590 sourceerr_temps[GF1], sourceerr_temps[GO], sourceerr_temps[GFS],
592 back_deltas->WriteTimeStep(t, curr_sourceerr);
595 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
596 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
598 }
while (dest_index.Decrement());
600 for (
int w = 0; w <
WT_COUNT; ++w) {
602 gate_errors_t[w].get()->PrintUnTransposed(10);
606 NetworkScratch::GradientStore source_t, state_t;
607 source_t.Init(na_, width, scratch);
609 state_t.Init(ns_, width, scratch);
610 state_.Transpose(state_t.get());
612 #pragma omp parallel for num_threads(GFS) if (!Is2D()) 614 for (
int w = 0; w <
WT_COUNT; ++w) {
615 if (w == GFS && !
Is2D())
continue;
618 if (softmax_ != NULL) {
623 back_deltas->CopyWithNormalization(*back_deltas, fwd_deltas);
void AccumulateVector(int n, const double *src, double *dest)
void init_to_size(int size, T t)
void CopyVector(int n, const double *src, double *dest)
const char * string() const
void VectorDotMatrix(const double *u, double *v) const
#define SECTION_IF_OPENMP
int Size(FlexDimensions dimension) const
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
#define PARALLEL_IF_OPENMP(__num_threads)
void Transpose(TransposedArray *dest) const
void ClipVector(int n, T lower, T upper, T *vec)
#define END_PARALLEL_IF_OPENMP
void DisplayBackward(const NetworkIO &matrix)
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
void FinishBackward(const TransposedArray &errors_t)
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)