tesseract
5.0.0-alpha-619-ge9db
|
Go to the documentation of this file.
40 :
Network(
type, name, ni, no), external_source_(nullptr), int_mode_(false) {
89 no_ = code_map.size();
121 int width = input.
Width();
132 temp_lines[i].Init(
no_, scratch);
133 curr_input[i].Init(
ni_, scratch);
136 #pragma omp parallel for num_threads(kNumThreads)
137 for (
int t = 0; t < width; ++t) {
139 int thread_id = omp_get_thread_num();
141 for (
int t = 0; t < width; ++t) {
145 double* temp_line = temp_lines[thread_id];
186 FuncInplace<GFunc>(
no_, output_line);
188 FuncInplace<FFunc>(
no_, output_line);
190 FuncInplace<ClipFFunc>(
no_, output_line);
192 FuncInplace<ClipGFunc>(
no_, output_line);
194 FuncInplace<Relu>(
no_, output_line);
198 ASSERT_HOST(
"Invalid fully-connected type!" ==
nullptr);
203 int t,
double* output_line) {
212 int t,
double* output_line) {
227 for (
int i = 0; i <
kNumThreads; ++i) errors[i].Init(
no_, scratch);
231 for (
int i = 0; i <
kNumThreads; ++i) temp_backprops[i].Init(
ni_, scratch);
233 int width = fwd_deltas.
Width();
235 errors_t.
Init(
no_, width, scratch);
237 #pragma omp parallel for num_threads(kNumThreads)
238 for (
int t = 0; t < width; ++t) {
239 int thread_id = omp_get_thread_num();
241 for (
int t = 0; t < width; ++t) {
244 double* backprop =
nullptr;
246 double* curr_errors = errors[thread_id];
248 if (backprop !=
nullptr) {
257 back_deltas->
Print(10);
280 fwd_deltas.ReadTimeStep(t, curr_errors);
282 ASSERT_HOST(
"Invalid fully-connected type!" ==
nullptr);
298 float adam_beta,
int num_samples) {
306 double* changed)
const {
308 const auto* fc = static_cast<const FullyConnected*>(&other);
const TransposedArray * external_source_
int InitWeights(float range, TRand *randomizer) override
TransposedArray source_t_
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
const int8_t * i(int t) const
void FinishBackward(const TransposedArray &errors_t)
virtual void SetRandomizer(TRand *randomizer)
void DisplayForward(const NetworkIO &matrix)
void Update(float learning_rate, float momentum, float adam_beta, int num_samples) override
bool DeSerialize(TFile *fp) override
void ZeroInvalidElements()
bool Backward(bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void CountAlternators(const Network &other, double *same, double *changed) const override
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
int RemapOutputs(int old_no, const std::vector< int > &code_map) override
void WriteStrided(int t, const float *data)
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
void ResizeFloat(const NetworkIO &src, int num_features)
void Debug2D(const char *msg)
void Init(int size1, int size2, NetworkScratch *scratch)
bool TestFlag(NetworkFlags flag) const
void SetEnableTraining(TrainingState state) override
TransposedArray * get() const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void SoftmaxInPlace(int n, T *inout)
void Forward(bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override
const char * c_str() const
void ForwardTimeStep(int t, double *output_line)
bool DeSerialize(bool training, TFile *fp)
void ResizeNoInit(int size1, int size2, int pad=0)
StaticShape OutputShape(const StaticShape &input_shape) const override
bool Serialize(bool training, TFile *fp) const
void Print(int num) const
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
FullyConnected(const STRING &name, int ni, int no, NetworkType type)
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
bool Serialize(TFile *fp) const override
void ReadTimeStep(int t, double *output) const
void set_loss_type(LossType value)
int RemapOutputs(const std::vector< int > &code_map)
void ConvertToInt() override
void WriteTimeStep(int t, const double *input)
void MatrixDotVector(const double *u, double *v) const
void Resize(const NetworkIO &src, int num_features)
void init_to_size(int size, const T &t)
void VectorDotMatrix(const double *u, double *v) const
DLLSYM void tprintf(const char *format,...)
virtual bool Serialize(TFile *fp) const
void DisplayBackward(const NetworkIO &matrix)
void DebugWeights() override
void set_depth(int value)