445 back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_,
ni_);
448 NetworkScratch::FloatVec outputerr;
449 outputerr.Init(ns_, scratch);
451 NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
452 curr_stateerr.Init(ns_, scratch);
453 curr_sourceerr.Init(na_, scratch);
454 ZeroVector<double>(ns_, curr_stateerr);
455 ZeroVector<double>(na_, curr_sourceerr);
457 NetworkScratch::FloatVec gate_errors[
WT_COUNT];
458 for (
int g = 0; g <
WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
464 stateerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
465 sourceerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
466 for (
int t = 0; t < buf_width; ++t) {
467 stateerr[t].Init(ns_, scratch);
468 sourceerr[t].Init(na_, scratch);
469 ZeroVector<double>(ns_, stateerr[t]);
470 ZeroVector<double>(na_, sourceerr[t]);
474 NetworkScratch::FloatVec sourceerr_temps[
WT_COUNT];
476 sourceerr_temps[w].Init(na_, scratch);
477 int width = input_width_;
479 NetworkScratch::GradientStore gate_errors_t[
WT_COUNT];
480 for (
int w = 0; w <
WT_COUNT; ++w) {
481 gate_errors_t[w].Init(ns_, width, scratch);
484 NetworkScratch::FloatVec softmax_errors;
485 NetworkScratch::GradientStore softmax_errors_t;
486 if (softmax_ !=
nullptr) {
487 softmax_errors.Init(
no_, scratch);
488 softmax_errors_t.Init(
no_, width, scratch);
490 double state_clip =
Is2D() ? 9.0 : 4.0;
493 fwd_deltas.Print(10);
495 StrideMap::Index dest_index(input_map_);
496 dest_index.InitToLast();
498 StrideMap::Index src_index(fwd_deltas.stride_map());
499 src_index.InitToLast();
501 int t = dest_index.t();
502 bool at_last_x = dest_index.IsLast(
FD_WIDTH);
509 StrideMap::Index up_index(dest_index);
510 if (up_index.AddOffset(-1,
FD_HEIGHT)) up_pos = up_index.t();
513 StrideMap::Index down_index(dest_index);
514 if (down_index.AddOffset(1,
FD_HEIGHT)) down_pos = down_index.t();
518 int mod_t =
Modulo(t, buf_width);
521 ZeroVector<double>(na_, curr_sourceerr);
522 ZeroVector<double>(ns_, curr_stateerr);
527 fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
528 src_index.Decrement();
530 ZeroVector<double>(ns_, outputerr);
532 }
else if (softmax_ ==
nullptr) {
533 fwd_deltas.ReadTimeStep(t, outputerr);
536 softmax_errors_t.get(), outputerr);
544 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
545 for (
int i = 0; i < ns_; ++i) {
546 curr_stateerr[i] *= next_node_gf1[i];
549 if (
Is2D() && t + 1 < width) {
550 for (
int i = 0; i < ns_; ++i) {
551 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
554 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
555 const double* right_stateerr = stateerr[mod_t];
556 for (
int i = 0; i < ns_; ++i) {
557 if (which_fg_[down_pos][i] == 2) {
558 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
566 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
568 if (t + 10 > width) {
570 for (
int i = 0; i < ns_; ++i)
571 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
572 curr_sourceerr[
ni_ + nf_ + i]);
580 node_values_[
CI].FuncMultiply3<GPrime>(t, node_values_[
GI], t,
581 curr_stateerr, gate_errors[
CI]);
584 gate_errors_t[
CI].get()->WriteStrided(t, gate_errors[
CI]);
588 node_values_[
GI].FuncMultiply3<FPrime>(t, node_values_[
CI], t,
589 curr_stateerr, gate_errors[
GI]);
592 gate_errors_t[
GI].get()->WriteStrided(t, gate_errors[
GI]);
597 node_values_[
GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
601 sourceerr_temps[
GF1]);
603 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[
GF1][0]));
604 memset(sourceerr_temps[
GF1], 0, na_ *
sizeof(*sourceerr_temps[
GF1]));
606 gate_errors_t[
GF1].get()->WriteStrided(t, gate_errors[
GF1]);
610 node_values_[
GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
614 sourceerr_temps[
GFS]);
616 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[
GFS][0]));
617 memset(sourceerr_temps[
GFS], 0, na_ *
sizeof(*sourceerr_temps[
GFS]));
619 if (
Is2D()) gate_errors_t[
GFS].get()->WriteStrided(t, gate_errors[
GFS]);
623 state_.Func2Multiply3<HFunc, FPrime>(node_values_[
GO], t, outputerr,
627 gate_errors_t[
GO].get()->WriteStrided(t, gate_errors[
GO]);
631 sourceerr_temps[
GF1], sourceerr_temps[
GO], sourceerr_temps[
GFS],
633 back_deltas->WriteTimeStep(t, curr_sourceerr);
636 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
637 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
639 }
while (dest_index.Decrement());
641 for (
int w = 0; w <
WT_COUNT; ++w) {
643 gate_errors_t[w].get()->PrintUnTransposed(10);
647 NetworkScratch::GradientStore source_t, state_t;
648 source_t.Init(na_, width, scratch);
650 state_t.Init(ns_, width, scratch);
651 state_.Transpose(state_t.get());
653 #pragma omp parallel for num_threads(GFS) if (!Is2D()) 655 for (
int w = 0; w <
WT_COUNT; ++w) {
656 if (w ==
GFS && !
Is2D())
continue;
659 if (softmax_ !=
nullptr) {
#define END_PARALLEL_IF_OPENMP
const char * string() const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void DisplayBackward(const NetworkIO &matrix)
void VectorDotMatrix(const double *u, double *v) const
void CopyVector(int n, const double *src, double *dest)
void ClipVector(int n, T lower, T upper, T *vec)
int Size(FlexDimensions dimension) const
void Transpose(TransposedArray *dest) const
void init_to_size(int size, const T &t)
DLLSYM void tprintf(const char *format,...)
void FinishBackward(const TransposedArray &errors_t)
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
#define SECTION_IF_OPENMP
void AccumulateVector(int n, const double *src, double *dest)
#define PARALLEL_IF_OPENMP(__num_threads)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)