27 #if !defined(__GNUC__) && defined(_MSC_VER) 38 #define PARALLEL_IF_OPENMP(__num_threads) \ 39 PRAGMA(omp parallel if (__num_threads > 1) num_threads(__num_threads)) { \ 40 PRAGMA(omp sections nowait) { \ 42 #define SECTION_IF_OPENMP \ 47 #define END_PARALLEL_IF_OPENMP \ 53 #ifdef _MSC_VER // Different _Pragma 54 #define PRAGMA(x) __pragma(x) 56 #define PRAGMA(x) _Pragma(#x) 60 #define PARALLEL_IF_OPENMP(__num_threads) 61 #define SECTION_IF_OPENMP 62 #define END_PARALLEL_IF_OPENMP 75 static inline uint32_t ceil_log2(uint32_t n)
80 uint32_t l2 = 31 - __builtin_clz(n);
81 #elif defined(_MSC_VER) 84 _BitScanReverse(&l2, n);
86 if (n == 0)
return UINT_MAX;
96 return (n == (1u << l2)) ? l2 : l2 + 1;
105 is_2d_(two_dimensional),
108 if (two_dimensional) na_ += ns_;
131 if (softmax_ !=
nullptr)
return softmax_->
OutputShape(result);
146 for (
int w = 0; w <
WT_COUNT; ++w) {
147 if (w ==
GFS && !
Is2D())
continue;
161 for (
int w = 0; w <
WT_COUNT; ++w) {
162 if (w ==
GFS && !
Is2D())
continue;
166 if (softmax_ !=
nullptr) {
175 if (softmax_ !=
nullptr) {
184 for (
int w = 0; w <
WT_COUNT; ++w) {
185 if (w ==
GFS && !
Is2D())
continue;
188 if (softmax_ !=
nullptr) {
195 for (
int w = 0; w <
WT_COUNT; ++w) {
196 if (w ==
GFS && !
Is2D())
continue;
201 if (softmax_ !=
nullptr) {
210 for (
int w = 0; w <
WT_COUNT; ++w) {
211 if (w ==
GFS && !
Is2D())
continue;
214 if (softmax_ !=
nullptr && !softmax_->
Serialize(fp))
return false;
225 nf_ = ceil_log2(
no_);
230 for (
int w = 0; w <
WT_COUNT; ++w) {
231 if (w ==
GFS && !
Is2D())
continue;
235 is_2d_ = na_ - nf_ ==
ni_ + 2 * ns_;
241 if (softmax_ ==
nullptr)
return false;
254 input_width_ = input.
Width();
255 if (softmax_ !=
nullptr)
261 ResizeForward(input);
264 for (
int i = 0; i <
WT_COUNT; ++i) temp_lines[i].Init(ns_, scratch);
267 curr_state.
Init(ns_, scratch);
268 ZeroVector<double>(ns_, curr_state);
269 curr_output.
Init(ns_, scratch);
270 ZeroVector<double>(ns_, curr_output);
279 for (
int i = 0; i < buf_width; ++i) {
280 states[i].Init(ns_, scratch);
281 ZeroVector<double>(ns_, states[i]);
282 outputs[i].Init(ns_, scratch);
283 ZeroVector<double>(ns_, outputs[i]);
289 if (softmax_ !=
nullptr) {
290 softmax_output.
Init(
no_, scratch);
291 ZeroVector<double>(
no_, softmax_output);
292 int rounded_softmax_inputs = gate_weights_[
CI].
RoundInputs(ns_);
294 int_output.
Resize2d(
true, 1, rounded_softmax_inputs, scratch);
298 curr_input.
Init(na_, scratch);
303 int t = src_index.
t();
305 bool valid_2d =
Is2D();
311 int mod_t =
Modulo(t, buf_width);
314 if (softmax_ !=
nullptr) {
328 gate_weights_[
CI].MatrixDotVector(source_.
i(t), temp_lines[
CI]);
331 FuncInplace<GFunc>(ns_, temp_lines[
CI]);
336 gate_weights_[
GI].MatrixDotVector(source_.
i(t), temp_lines[
GI]);
339 FuncInplace<FFunc>(ns_, temp_lines[
GI]);
344 gate_weights_[
GF1].MatrixDotVector(source_.
i(t), temp_lines[
GF1]);
347 FuncInplace<FFunc>(ns_, temp_lines[
GF1]);
352 gate_weights_[
GFS].MatrixDotVector(source_.
i(t), temp_lines[
GFS]);
355 FuncInplace<FFunc>(ns_, temp_lines[
GFS]);
361 gate_weights_[
GO].MatrixDotVector(source_.
i(t), temp_lines[
GO]);
364 FuncInplace<FFunc>(ns_, temp_lines[
GO]);
371 int8_t* which_fg_col = which_fg_[t];
372 memset(which_fg_col, 1, ns_ *
sizeof(which_fg_col[0]));
374 const double* stepped_state = states[mod_t];
375 for (
int i = 0; i < ns_; ++i) {
376 if (temp_lines[
GF1][i] < temp_lines[
GFS][i]) {
377 curr_state[i] = temp_lines[
GFS][i] * stepped_state[i];
394 FuncMultiply<HFunc>(curr_state, temp_lines[
GO], ns_, curr_output);
396 if (softmax_ !=
nullptr) {
411 dest_index.Increment();
424 ZeroVector<double>(ns_, curr_state);
425 ZeroVector<double>(ns_, curr_output);
449 outputerr.
Init(ns_, scratch);
452 curr_stateerr.
Init(ns_, scratch);
453 curr_sourceerr.
Init(na_, scratch);
454 ZeroVector<double>(ns_, curr_stateerr);
455 ZeroVector<double>(na_, curr_sourceerr);
458 for (
int g = 0; g <
WT_COUNT; ++g) gate_errors[g].Init(ns_, scratch);
466 for (
int t = 0; t < buf_width; ++t) {
467 stateerr[t].Init(ns_, scratch);
468 sourceerr[t].Init(na_, scratch);
469 ZeroVector<double>(ns_, stateerr[t]);
470 ZeroVector<double>(na_, sourceerr[t]);
476 sourceerr_temps[w].Init(na_, scratch);
477 int width = input_width_;
480 for (
int w = 0; w <
WT_COUNT; ++w) {
481 gate_errors_t[w].
Init(ns_, width, scratch);
486 if (softmax_ !=
nullptr) {
487 softmax_errors.
Init(
no_, scratch);
488 softmax_errors_t.
Init(
no_, width, scratch);
490 double state_clip =
Is2D() ? 9.0 : 4.0;
493 fwd_deltas.
Print(10);
501 int t = dest_index.
t();
518 int mod_t =
Modulo(t, buf_width);
521 ZeroVector<double>(na_, curr_sourceerr);
522 ZeroVector<double>(ns_, curr_stateerr);
528 src_index.Decrement();
530 ZeroVector<double>(ns_, outputerr);
532 }
else if (softmax_ ==
nullptr) {
536 softmax_errors_t.
get(), outputerr);
544 const float* next_node_gf1 = node_values_[
GF1].
f(t + 1);
545 for (
int i = 0; i < ns_; ++i) {
546 curr_stateerr[i] *= next_node_gf1[i];
549 if (
Is2D() && t + 1 < width) {
550 for (
int i = 0; i < ns_; ++i) {
551 if (which_fg_[t + 1][i] != 1) curr_stateerr[i] = 0.0;
554 const float* right_node_gfs = node_values_[
GFS].
f(down_pos);
555 const double* right_stateerr = stateerr[mod_t];
556 for (
int i = 0; i < ns_; ++i) {
557 if (which_fg_[down_pos][i] == 2) {
558 curr_stateerr[i] += right_stateerr[i] * right_node_gfs[i];
566 ClipVector<double>(ns_, -state_clip, state_clip, curr_stateerr);
568 if (t + 10 > width) {
570 for (
int i = 0; i < ns_; ++i)
571 tprintf(
" %g,%g,%g", curr_stateerr[i], outputerr[i],
572 curr_sourceerr[
ni_ + nf_ + i]);
580 node_values_[
CI].FuncMultiply3<
GPrime>(t, node_values_[
GI], t,
581 curr_stateerr, gate_errors[
CI]);
588 node_values_[
GI].FuncMultiply3<
FPrime>(t, node_values_[
CI], t,
589 curr_stateerr, gate_errors[
GI]);
597 node_values_[
GF1].FuncMultiply3<
FPrime>(t, state_, t - 1, curr_stateerr,
601 sourceerr_temps[
GF1]);
603 memset(gate_errors[
GF1], 0, ns_ *
sizeof(gate_errors[
GF1][0]));
604 memset(sourceerr_temps[
GF1], 0, na_ *
sizeof(*sourceerr_temps[
GF1]));
610 node_values_[
GFS].FuncMultiply3<
FPrime>(t, state_, up_pos, curr_stateerr,
614 sourceerr_temps[
GFS]);
616 memset(gate_errors[
GFS], 0, ns_ *
sizeof(gate_errors[
GFS][0]));
617 memset(sourceerr_temps[
GFS], 0, na_ *
sizeof(*sourceerr_temps[
GFS]));
631 sourceerr_temps[
GF1], sourceerr_temps[
GO], sourceerr_temps[
GFS],
636 CopyVector(ns_, curr_stateerr, stateerr[mod_t]);
637 CopyVector(na_, curr_sourceerr, sourceerr[mod_t]);
641 for (
int w = 0; w <
WT_COUNT; ++w) {
648 source_t.
Init(na_, width, scratch);
650 state_t.
Init(ns_, width, scratch);
653 #pragma omp parallel for num_threads(GFS) if (!Is2D()) 655 for (
int w = 0; w <
WT_COUNT; ++w) {
656 if (w ==
GFS && !
Is2D())
continue;
659 if (softmax_ !=
nullptr) {
667 void LSTM::Update(
float learning_rate,
float momentum,
float adam_beta,
672 for (
int w = 0; w <
WT_COUNT; ++w) {
673 if (w ==
GFS && !
Is2D())
continue;
674 gate_weights_[w].
Update(learning_rate, momentum, adam_beta, num_samples);
676 if (softmax_ !=
nullptr) {
677 softmax_->
Update(learning_rate, momentum, adam_beta, num_samples);
688 double* changed)
const {
690 const LSTM* lstm =
static_cast<const LSTM*
>(&other);
691 for (
int w = 0; w <
WT_COUNT; ++w) {
692 if (w ==
GFS && !
Is2D())
continue;
695 if (softmax_ !=
nullptr) {
703 for (
int w = 0; w <
WT_COUNT; ++w) {
704 if (w ==
GFS && !
Is2D())
continue;
705 tprintf(
"Gate %d, inputs\n", w);
706 for (
int i = 0; i <
ni_; ++i) {
708 for (
int s = 0; s < ns_; ++s)
709 tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
712 tprintf(
"Gate %d, outputs\n", w);
713 for (
int i =
ni_; i <
ni_ + ns_; ++i) {
715 for (
int s = 0; s < ns_; ++s)
716 tprintf(
" %g", gate_weights_[w].GetWeights(s)[i]);
720 for (
int s = 0; s < ns_; ++s)
721 tprintf(
" %g", gate_weights_[w].GetWeights(s)[na_]);
729 for (
int w = 0; w <
WT_COUNT; ++w) {
730 if (w ==
GFS && !
Is2D())
continue;
731 tprintf(
"Gate %d, inputs\n", w);
732 for (
int i = 0; i <
ni_; ++i) {
734 for (
int s = 0; s < ns_; ++s)
735 tprintf(
" %g", gate_weights_[w].GetDW(s, i));
738 tprintf(
"Gate %d, outputs\n", w);
739 for (
int i =
ni_; i <
ni_ + ns_; ++i) {
741 for (
int s = 0; s < ns_; ++s)
742 tprintf(
" %g", gate_weights_[w].GetDW(s, i));
746 for (
int s = 0; s < ns_; ++s)
747 tprintf(
" %g", gate_weights_[w].GetDW(s, na_));
753 void LSTM::ResizeForward(
const NetworkIO& input) {
755 source_.
Resize(input, rounded_inputs);
759 for (
int w = 0; w <
WT_COUNT; ++w) {
760 if (w ==
GFS && !
Is2D())
continue;
void Init(int size, NetworkScratch *scratch)
int RoundInputs(int size) const
LSTM(const STRING &name, int num_inputs, int num_states, int num_outputs, bool two_dimensional, NetworkType type)
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
static Network * CreateFromFile(TFile *fp)
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
void Print(int num) const
#define END_PARALLEL_IF_OPENMP
void ConvertToInt() override
void PrintUnTransposed(int num)
bool Serialize(TFile *fp) const override
void CodeInBinary(int n, int nf, double *vec)
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
bool AddOffset(int offset, FlexDimensions dimension)
const char * string() const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void Init(int size1, int size2, NetworkScratch *scratch)
bool DeSerialize(char *data, size_t count=1)
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void Resize2d(bool int_mode, int width, int num_features, NetworkScratch *scratch)
void DisplayBackward(const NetworkIO &matrix)
int InitWeights(float range, TRand *randomizer) override
virtual void SetRandomizer(TRand *randomizer)
void VectorDotMatrix(const double *u, double *v) const
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void WriteTimeStep(int t, const double *input)
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
void CopyVector(int n, const double *src, double *dest)
const int8_t * i(int t) const
void ResizeNoInit(int size1, int size2, int pad=0)
void ClipVector(int n, T lower, T upper, T *vec)
bool IsLast(FlexDimensions dimension) const
void MultiplyVectorsInPlace(int n, const double *src, double *inout)
int Size(FlexDimensions dimension) const
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
void set_width(int value)
void ConvertToInt() override
void Resize(const NetworkIO &src, int num_features)
void MultiplyAccumulate(int n, const double *u, const double *v, double *out)
void Debug2D(const char *msg)
bool Serialize(TFile *fp) const override
bool DeSerialize(TFile *fp) override
void Transpose(TransposedArray *dest) const
StaticShape OutputShape(const StaticShape &input_shape) const override
virtual bool Serialize(TFile *fp) const
int index(FlexDimensions dimension) const
void init_to_size(int size, const T &t)
void DebugWeights() override
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
void SetEnableTraining(TrainingState state) override
void SetEnableTraining(TrainingState state) override
TransposedArray * get() const
void DisplayForward(const NetworkIO &matrix)
void Func2Multiply3(const NetworkIO &v_io, int t, const double *w, double *product) const
StaticShape OutputShape(const StaticShape &input_shape) const override
void CountAlternators(const Network &other, double *same, double *changed) const override
bool Serialize(const char *data, size_t count=1)
const StrideMap & stride_map() const
DLLSYM void tprintf(const char *format,...)
void MatrixDotVector(const double *u, double *v) const
void WriteTimeStepPart(int t, int offset, int num_features, const double *input)
void add_str_int(const char *str, int number)
void ForwardTimeStep(int t, double *output_line)
void FinishBackward(const TransposedArray &errors_t)
void CopyTimeStepGeneral(int dest_t, int dest_offset, int num_features, const NetworkIO &src, int src_t, int src_offset)
void SumVectors(int n, const double *v1, const double *v2, const double *v3, const double *v4, const double *v5, double *sum)
#define SECTION_IF_OPENMP
void AccumulateVector(int n, const double *src, double *dest)
#define PARALLEL_IF_OPENMP(__num_threads)
void set_depth(int value)
void FuncMultiply3Add(const NetworkIO &v_io, int t, const double *w, double *product) const
void CountAlternators(const Network &other, double *same, double *changed) const override
void ResizeToMap(bool int_mode, const StrideMap &stride_map, int num_features)
void ReadTimeStep(int t, double *output) const
int InitWeights(float range, TRand *randomizer) override
void WriteStrided(int t, const float *data)
void DebugWeights() override
bool TestFlag(NetworkFlags flag) const
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
void ResizeXTo1(const NetworkIO &src, int num_features)
void ResizeFloat(const NetworkIO &src, int num_features)
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)