444 back_deltas->ResizeToMap(fwd_deltas.int_mode(), input_map_,
ni_);
447 NetworkScratch::FloatVec outputerr;
448 outputerr.Init(ns_, scratch);
450 NetworkScratch::FloatVec curr_stateerr, curr_sourceerr;
451 curr_stateerr.Init(ns_, scratch);
452 curr_sourceerr.Init(na_, scratch);
453 ZeroVector<double>(ns_, curr_stateerr);
454 ZeroVector<double>(na_, curr_sourceerr);
456 NetworkScratch::FloatVec gate_errors[
WT_COUNT];
457 for (
auto & gate_error : gate_errors) gate_error.Init(ns_, scratch);
463 stateerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
464 sourceerr.
init_to_size(buf_width, NetworkScratch::FloatVec());
465 for (
int t = 0; t < buf_width; ++t) {
466 stateerr[t].Init(ns_, scratch);
467 sourceerr[t].Init(na_, scratch);
468 ZeroVector<double>(ns_, stateerr[t]);
469 ZeroVector<double>(na_, sourceerr[t]);
473 NetworkScratch::FloatVec sourceerr_temps[
WT_COUNT];
474 for (
auto & sourceerr_temp : sourceerr_temps)
475 sourceerr_temp.Init(na_, scratch);
476 int width = input_width_;
478 NetworkScratch::GradientStore gate_errors_t[
WT_COUNT];
479 for (
auto & w : gate_errors_t) {
480 w.Init(ns_, width, scratch);
483 NetworkScratch::FloatVec softmax_errors;
484 NetworkScratch::GradientStore softmax_errors_t;
485 if (softmax_ !=
nullptr) {
486 softmax_errors.Init(
no_, scratch);
487 softmax_errors_t.Init(
no_, width, scratch);
489 double state_clip =
Is2D() ? 9.0 : 4.0;
492 fwd_deltas.Print(10);
494 StrideMap::Index dest_index(input_map_);
495 dest_index.InitToLast();
497 StrideMap::Index src_index(fwd_deltas.stride_map());
498 src_index.InitToLast();
500 int t = dest_index.t();
501 bool at_last_x = dest_index.IsLast(
FD_WIDTH);
508 StrideMap::Index up_index(dest_index);
509 if (up_index.AddOffset(-1,
FD_HEIGHT)) up_pos = up_index.t();
512 StrideMap::Index down_index(dest_index);
513 if (down_index.AddOffset(1,
FD_HEIGHT)) down_pos = down_index.t();
517 int mod_t =
Modulo(t, buf_width);
520 ZeroVector<double>(na_, curr_sourceerr);
521 ZeroVector<double>(ns_, curr_stateerr);
526 fwd_deltas.ReadTimeStep(src_index.t(), outputerr);
527 src_index.Decrement();
529 ZeroVector<double>(ns_, outputerr);
531 }
else if (softmax_ ==
nullptr) {
532 fwd_deltas.ReadTimeStep(t, outputerr);
535 softmax_errors_t.get(), outputerr);
543 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
544 for (
int i = 0; i < ns_; ++i) {
545 curr_stateerr[i] *= next_node_gf1[i];
548 if (
Is2D() && t + 1 < width) {
549 for (
int i = 0; i < ns_; ++i) {
550 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
553 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
554 const double* right_stateerr = stateerr[mod_t];
555 for (
int i = 0; i < ns_; ++i) {
556 if (which_fg_[down_pos][i] == 2) {
557 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
565 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
567 if (t + 10 > width) {
569 for (
int i = 0; i < ns_; ++i)
570 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
571 curr_sourceerr[
ni_ + nf_ + i]);
579 node_values_[
CI].FuncMultiply3<GPrime>(t, node_values_[
GI], t,
580 curr_stateerr, gate_errors[
CI]);
583 gate_errors_t[
CI].get()->WriteStrided(t, gate_errors[
CI]);
587 node_values_[
GI].FuncMultiply3<FPrime>(t, node_values_[
CI], t,
588 curr_stateerr, gate_errors[
GI]);
591 gate_errors_t[
GI].get()->WriteStrided(t, gate_errors[
GI]);
596 node_values_[
GF1].FuncMultiply3<FPrime>(t, state_, t - 1, curr_stateerr,
600 sourceerr_temps[
GF1]);
602 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[
GF1][0]));
603 memset(sourceerr_temps[
GF1], 0, na_ *
sizeof(*sourceerr_temps[
GF1]));
605 gate_errors_t[
GF1].get()->WriteStrided(t, gate_errors[
GF1]);
609 node_values_[
GFS].FuncMultiply3<FPrime>(t, state_, up_pos, curr_stateerr,
613 sourceerr_temps[
GFS]);
615 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[
GFS][0]));
616 memset(sourceerr_temps[
GFS], 0, na_ *
sizeof(*sourceerr_temps[
GFS]));
618 if (
Is2D()) gate_errors_t[
GFS].get()->WriteStrided(t, gate_errors[
GFS]);
622 state_.Func2Multiply3<HFunc, FPrime>(node_values_[
GO], t, outputerr,
626 gate_errors_t[
GO].get()->WriteStrided(t, gate_errors[
GO]);
630 sourceerr_temps[
GF1], sourceerr_temps[
GO], sourceerr_temps[
GFS],
632 back_deltas->WriteTimeStep(t, curr_sourceerr);
635 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
636 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
638 }
while (dest_index.Decrement());
640 for (
int w = 0; w <
WT_COUNT; ++w) {
642 gate_errors_t[w].get()->PrintUnTransposed(10);
646 NetworkScratch::GradientStore source_t, state_t;
647 source_t.Init(na_, width, scratch);
649 state_t.Init(ns_, width, scratch);
650 state_.Transpose(state_t.get());
652 #pragma omp parallel for num_threads(GFS) if (!Is2D())
654 for (
int w = 0; w <
WT_COUNT; ++w) {
655 if (w ==
GFS && !
Is2D())
continue;
658 if (softmax_ !=
nullptr) {