26 #if !defined(__GNUC__) && defined(_MSC_VER)
37 #define PARALLEL_IF_OPENMP(__num_threads) \
38 PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \
39 PRAGMA(omp sections nowait) { \
41 #define SECTION_IF_OPENMP \
46 #define END_PARALLEL_IF_OPENMP \
52 #ifdef _MSC_VER // Different _Pragma
53 #define PRAGMA(x) __pragma(x)
55 #define PRAGMA(x) _Pragma(#x)
59 #define PARALLEL_IF_OPENMP(__num_threads)
60 #define SECTION_IF_OPENMP
61 #define END_PARALLEL_IF_OPENMP
74 static inline uint32_t ceil_log2(uint32_t n)
79 uint32_t l2 = 31 - __builtin_clz(n);
80 #elif defined(_MSC_VER)
83 _BitScanReverse(&l2, n);
85 if (n == 0)
return UINT_MAX;
95 return (n == (1u << l2)) ? l2 : l2 + 1;
104 is_2d_(two_dimensional),
107 if (two_dimensional) na_ += ns_;
130 if (softmax_ !=
nullptr)
return softmax_->
OutputShape(result);
145 for (
int w = 0; w <
WT_COUNT; ++w) {
146 if (w ==
GFS && !
Is2D())
continue;
160 for (
int w = 0; w <
WT_COUNT; ++w) {
161 if (w ==
GFS && !
Is2D())
continue;
165 if (softmax_ !=
nullptr) {
174 if (softmax_ !=
nullptr) {
183 for (
int w = 0; w <
WT_COUNT; ++w) {
184 if (w ==
GFS && !
Is2D())
continue;
187 if (softmax_ !=
nullptr) {
194 for (
int w = 0; w <
WT_COUNT; ++w) {
195 if (w ==
GFS && !
Is2D())
continue;
200 if (softmax_ !=
nullptr) {
209 for (
int w = 0; w <
WT_COUNT; ++w) {
210 if (w ==
GFS && !
Is2D())
continue;
213 if (softmax_ !=
nullptr && !softmax_->
Serialize(fp))
return false;
224 nf_ = ceil_log2(
no_);
229 for (
int w = 0; w <
WT_COUNT; ++w) {
230 if (w ==
GFS && !
Is2D())
continue;
234 is_2d_ = na_ - nf_ ==
ni_ + 2 * ns_;
240 if (softmax_ ==
nullptr)
return false;
253 input_width_ = input.
Width();
254 if (softmax_ !=
nullptr)
260 ResizeForward(input);
263 for (
auto & temp_line : temp_lines) temp_line.
Init(ns_, scratch);
266 curr_state.
Init(ns_, scratch);
267 ZeroVector<double>(ns_, curr_state);
268 curr_output.
Init(ns_, scratch);
269 ZeroVector<double>(ns_, curr_output);
278 for (
int i = 0; i < buf_width; ++i) {
279 states[i].Init(ns_, scratch);
280 ZeroVector<double>(ns_, states[i]);
281 outputs[i].Init(ns_, scratch);
282 ZeroVector<double>(ns_, outputs[i]);
288 if (softmax_ !=
nullptr) {
289 softmax_output.
Init(
no_, scratch);
290 ZeroVector<double>(
no_, softmax_output);
291 int rounded_softmax_inputs = gate_weights_[
CI].
RoundInputs(ns_);
293 int_output.
Resize2d(
true, 1, rounded_softmax_inputs, scratch);
297 curr_input.
Init(na_, scratch);
302 int t = src_index.
t();
304 bool valid_2d =
Is2D();
310 int mod_t =
Modulo(t, buf_width);
313 if (softmax_ !=
nullptr) {
327 gate_weights_[
CI].MatrixDotVector(source_.
i(t), temp_lines[
CI]);
330 FuncInplace<GFunc>(ns_, temp_lines[
CI]);
335 gate_weights_[
GI].MatrixDotVector(source_.
i(t), temp_lines[
GI]);
338 FuncInplace<FFunc>(ns_, temp_lines[
GI]);
343 gate_weights_[
GF1].MatrixDotVector(source_.
i(t), temp_lines[
GF1]);
346 FuncInplace<FFunc>(ns_, temp_lines[
GF1]);
351 gate_weights_[
GFS].MatrixDotVector(source_.
i(t), temp_lines[
GFS]);
354 FuncInplace<FFunc>(ns_, temp_lines[
GFS]);
360 gate_weights_[
GO].MatrixDotVector(source_.
i(t), temp_lines[
GO]);
363 FuncInplace<FFunc>(ns_, temp_lines[
GO]);
370 int8_t* which_fg_col = which_fg_[t];
371 memset(which_fg_col, 1, ns_ *
sizeof(which_fg_col[0]));
373 const double* stepped_state = states[mod_t];
374 for (
int i = 0; i < ns_; ++i) {
375 if (temp_lines[
GF1][i] < temp_lines[
GFS][i]) {
376 curr_state[i] = temp_lines[
GFS][i] * stepped_state[i];
393 FuncMultiply<HFunc>(curr_state, temp_lines[
GO], ns_, curr_output);
395 if (softmax_ !=
nullptr) {
410 dest_index.Increment();
423 ZeroVector<double>(ns_, curr_state);
424 ZeroVector<double>(ns_, curr_output);
448 outputerr.
Init(ns_, scratch);
451 curr_stateerr.
Init(ns_, scratch);
452 curr_sourceerr.
Init(na_, scratch);
453 ZeroVector<double>(ns_, curr_stateerr);
454 ZeroVector<double>(na_, curr_sourceerr);
457 for (
auto & gate_error : gate_errors) gate_error.
Init(ns_, scratch);
465 for (
int t = 0; t < buf_width; ++t) {
466 stateerr[t].Init(ns_, scratch);
467 sourceerr[t].Init(na_, scratch);
468 ZeroVector<double>(ns_, stateerr[t]);
469 ZeroVector<double>(na_, sourceerr[t]);
474 for (
auto & sourceerr_temp : sourceerr_temps)
475 sourceerr_temp.
Init(na_, scratch);
476 int width = input_width_;
479 for (
auto & w : gate_errors_t) {
480 w.
Init(ns_, width, scratch);
485 if (softmax_ !=
nullptr) {
486 softmax_errors.
Init(
no_, scratch);
487 softmax_errors_t.
Init(
no_, width, scratch);
489 double state_clip =
Is2D() ? 9.0 : 4.0;
492 fwd_deltas.
Print(10);
500 int t = dest_index.
t();
517 int mod_t =
Modulo(t, buf_width);
520 ZeroVector<double>(na_, curr_sourceerr);
521 ZeroVector<double>(ns_, curr_stateerr);
527 src_index.Decrement();
529 ZeroVector<double>(ns_, outputerr);
531 }
else if (softmax_ ==
nullptr) {
535 softmax_errors_t.
get(), outputerr);
543 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
544 for (
int i = 0; i < ns_; ++i) {
545 curr_stateerr[i] *= next_node_gf1[i];
548 if (
Is2D() && t + 1 < width) {
549 for (
int i = 0; i < ns_; ++i) {
550 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
553 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
554 const double* right_stateerr = stateerr[mod_t];
555 for (
int i = 0; i < ns_; ++i) {
556 if (which_fg_[down_pos][i] == 2) {
557 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
565 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
567 if (t + 10 > width) {
569 for (
int i = 0; i < ns_; ++i)
570 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
571 curr_sourceerr[
ni_ + nf_ + i]);
579 node_values_[
CI].FuncMultiply3<
GPrime>(t, node_values_[
GI], t,
580 curr_stateerr, gate_errors[
CI]);
583 gate_errors_t[
CI].get()->WriteStrided(t, gate_errors[
CI]);
587 node_values_[
GI].FuncMultiply3<
FPrime>(t, node_values_[
CI], t,
588 curr_stateerr, gate_errors[
GI]);
591 gate_errors_t[
GI].get()->WriteStrided(t, gate_errors[
GI]);
596 node_values_[
GF1].FuncMultiply3<
FPrime>(t, state_, t - 1, curr_stateerr,
600 sourceerr_temps[
GF1]);
602 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[
GF1][0]));
603 memset(sourceerr_temps[
GF1], 0, na_ *
sizeof(*sourceerr_temps[
GF1]));
605 gate_errors_t[
GF1].get()->WriteStrided(t, gate_errors[
GF1]);
609 node_values_[
GFS].FuncMultiply3<
FPrime>(t, state_, up_pos, curr_stateerr,
613 sourceerr_temps[
GFS]);
615 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[
GFS][0]));
616 memset(sourceerr_temps[
GFS], 0, na_ *
sizeof(*sourceerr_temps[
GFS]));
618 if (
Is2D()) gate_errors_t[
GFS].get()->WriteStrided(t, gate_errors[
GFS]);
626 gate_errors_t[
GO].get()->WriteStrided(t, gate_errors[
GO]);
630 sourceerr_temps[
GF1], sourceerr_temps[
GO], sourceerr_temps[
GFS],
635 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
636 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
640 for (
int w = 0; w <
WT_COUNT; ++w) {
642 gate_errors_t[w].get()->PrintUnTransposed(10);
647 source_t.
Init(na_, width, scratch);
649 state_t.
Init(ns_, width, scratch);
652 #pragma omp parallel for num_threads(GFS) if (!Is2D())
654 for (
int w = 0; w <
WT_COUNT; ++w) {
655 if (w ==
GFS && !
Is2D())
continue;
658 if (softmax_ !=
nullptr) {
666 void LSTM::Update(
float learning_rate,
float momentum,
float adam_beta,
671 for (
int w = 0; w <
WT_COUNT; ++w) {
672 if (w ==
GFS && !
Is2D())
continue;
673 gate_weights_[w].
Update(learning_rate, momentum, adam_beta, num_samples);
675 if (softmax_ !=
nullptr) {
676 softmax_->
Update(learning_rate, momentum, adam_beta, num_samples);
687 double* changed)
const {
689 const LSTM* lstm = static_cast<const LSTM*>(&other);
690 for (
int w = 0; w <
WT_COUNT; ++w) {
691 if (w ==
GFS && !
Is2D())
continue;
694 if (softmax_ !=
nullptr) {
702 for (
int w = 0; w <
WT_COUNT; ++w) {
703 if (w ==
GFS && !
Is2D())
continue;
704 tprintf(
"Gate %d, inputs\n", w);
705 for (
int i = 0; i <
ni_; ++i) {
707 for (
int s = 0; s < ns_; ++s)
708 tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
711 tprintf(
"Gate %d, outputs\n", w);
712 for (
int i =
ni_; i <
ni_ + ns_; ++i) {
714 for (
int s = 0; s < ns_; ++s)
715 tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
719 for (
int s = 0; s < ns_; ++s)
720 tprintf(
" %g", gate_weights_[w].GetWeights(s)[na_]);
728 for (
int w = 0; w <
WT_COUNT; ++w) {
729 if (w ==
GFS && !
Is2D())
continue;
730 tprintf(
"Gate %d, inputs\n", w);
731 for (
int i = 0; i <
ni_; ++i) {
733 for (
int s = 0; s < ns_; ++s)
734 tprintf(
" %g", gate_weights_[w].GetDW(s, i));
737 tprintf(
"Gate %d, outputs\n", w);
738 for (
int i =
ni_; i <
ni_ + ns_; ++i) {
740 for (
int s = 0; s < ns_; ++s)
741 tprintf(
" %g", gate_weights_[w].GetDW(s, i));
745 for (
int s = 0; s < ns_; ++s)
746 tprintf(
" %g", gate_weights_[w].GetDW(s, na_));
752 void LSTM::ResizeForward(
const NetworkIO& input) {
754 source_.
Resize(input, rounded_inputs);
758 for (
int w = 0; w <
WT_COUNT; ++w) {
759 if (w ==
GFS && !
Is2D())
continue;