tesseract
5.0.0-alpha-619-ge9db
|
#include <fullyconnected.h>
|
| FullyConnected (const STRING &name, int ni, int no, NetworkType type) |
|
| ~FullyConnected () override=default |
|
StaticShape | OutputShape (const StaticShape &input_shape) const override |
|
STRING | spec () const override |
|
void | ChangeType (NetworkType type) |
|
void | SetEnableTraining (TrainingState state) override |
|
int | InitWeights (float range, TRand *randomizer) override |
|
int | RemapOutputs (int old_no, const std::vector< int > &code_map) override |
|
void | ConvertToInt () override |
|
void | DebugWeights () override |
|
bool | Serialize (TFile *fp) const override |
|
bool | DeSerialize (TFile *fp) override |
|
void | Forward (bool debug, const NetworkIO &input, const TransposedArray *input_transpose, NetworkScratch *scratch, NetworkIO *output) override |
|
void | SetupForward (const NetworkIO &input, const TransposedArray *input_transpose) |
|
void | ForwardTimeStep (int t, double *output_line) |
|
void | ForwardTimeStep (const double *d_input, int t, double *output_line) |
|
void | ForwardTimeStep (const int8_t *i_input, int t, double *output_line) |
|
bool | Backward (bool debug, const NetworkIO &fwd_deltas, NetworkScratch *scratch, NetworkIO *back_deltas) override |
|
void | BackwardTimeStep (const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop) |
|
void | FinishBackward (const TransposedArray &errors_t) |
|
void | Update (float learning_rate, float momentum, float adam_beta, int num_samples) override |
|
void | CountAlternators (const Network &other, double *same, double *changed) const override |
|
| Network () |
|
| Network (NetworkType type, const STRING &name, int ni, int no) |
|
virtual | ~Network ()=default |
|
NetworkType | type () const |
|
bool | IsTraining () const |
|
bool | needs_to_backprop () const |
|
int | num_weights () const |
|
int | NumInputs () const |
|
int | NumOutputs () const |
|
virtual StaticShape | InputShape () const |
|
const STRING & | name () const |
|
bool | TestFlag (NetworkFlags flag) const |
|
virtual bool | IsPlumbingType () const |
|
virtual void | SetNetworkFlags (uint32_t flags) |
|
virtual void | SetRandomizer (TRand *randomizer) |
|
virtual bool | SetupNeedsBackprop (bool needs_backprop) |
|
virtual int | XScaleFactor () const |
|
virtual void | CacheXScaleFactor (int factor) |
|
void | DisplayForward (const NetworkIO &matrix) |
|
void | DisplayBackward (const NetworkIO &matrix) |
|
Definition at line 28 of file fullyconnected.h.
◆ FullyConnected()
tesseract::FullyConnected::FullyConnected |
( |
const STRING & |
name, |
|
|
int |
ni, |
|
|
int |
no, |
|
|
NetworkType |
type |
|
) |
| |
◆ ~FullyConnected()
tesseract::FullyConnected::~FullyConnected |
( |
| ) |
|
|
override default |
◆ Backward()
Implements tesseract::Network.
Definition at line 220 of file fullyconnected.cpp.
224 back_deltas->Resize(fwd_deltas,
ni_);
227 for (
int i = 0; i <
kNumThreads; ++i) errors[i].Init(
no_, scratch);
231 for (
int i = 0; i <
kNumThreads; ++i) temp_backprops[i].Init(
ni_, scratch);
233 int width = fwd_deltas.Width();
234 NetworkScratch::GradientStore errors_t;
235 errors_t.Init(
no_, width, scratch);
237 #pragma omp parallel for num_threads(kNumThreads)
238 for (
int t = 0; t < width; ++t) {
239 int thread_id = omp_get_thread_num();
241 for (
int t = 0; t < width; ++t) {
244 double* backprop =
nullptr;
246 double* curr_errors = errors[thread_id];
248 if (backprop !=
nullptr) {
249 back_deltas->WriteTimeStep(t, backprop);
254 back_deltas->ZeroInvalidElements();
257 back_deltas->Print(10);
◆ BackwardTimeStep()
void tesseract::FullyConnected::BackwardTimeStep |
( |
const NetworkIO & |
fwd_deltas, |
|
|
int |
t, |
|
|
double * |
curr_errors, |
|
|
TransposedArray * |
errors_t, |
|
|
double * |
backprop |
|
) |
| |
Definition at line 264 of file fullyconnected.cpp.
280 fwd_deltas.ReadTimeStep(t, curr_errors);
282 ASSERT_HOST(
"Invalid fully-connected type!" ==
nullptr);
285 errors_t->WriteStrided(t, curr_errors);
◆ ChangeType()
void tesseract::FullyConnected::ChangeType |
( |
NetworkType |
type | ) |
|
|
inline |
◆ ConvertToInt()
void tesseract::FullyConnected::ConvertToInt |
( |
| ) |
|
|
override virtual |
◆ CountAlternators()
void tesseract::FullyConnected::CountAlternators |
( |
const Network & |
other, |
|
|
double * |
same, |
|
|
double * |
changed |
|
) |
| const |
|
override virtual |
◆ DebugWeights()
void tesseract::FullyConnected::DebugWeights |
( |
| ) |
|
|
override virtual |
◆ DeSerialize()
bool tesseract::FullyConnected::DeSerialize |
( |
TFile * |
fp | ) |
|
|
override virtual |
◆ FinishBackward()
void tesseract::FullyConnected::FinishBackward |
( |
const TransposedArray & |
errors_t | ) |
|
◆ Forward()
Implements tesseract::Network.
Definition at line 118 of file fullyconnected.cpp.
121 int width = input.Width();
123 output->ResizeFloat(input,
no_);
125 output->Resize(input,
no_);
132 temp_lines[i].Init(
no_, scratch);
133 curr_input[i].Init(
ni_, scratch);
136 #pragma omp parallel for num_threads(kNumThreads)
137 for (
int t = 0; t < width; ++t) {
139 int thread_id = omp_get_thread_num();
141 for (
int t = 0; t < width; ++t) {
145 double* temp_line = temp_lines[thread_id];
146 if (input.int_mode()) {
149 input.ReadTimeStep(t, curr_input[thread_id]);
152 output->WriteTimeStep(t, temp_line);
163 output->ZeroInvalidElements();
◆ ForwardTimeStep() [1/3]
void tesseract::FullyConnected::ForwardTimeStep |
( |
const double * |
d_input, |
|
|
int |
t, |
|
|
double * |
output_line |
|
) |
| |
◆ ForwardTimeStep() [2/3]
void tesseract::FullyConnected::ForwardTimeStep |
( |
const int8_t * |
i_input, |
|
|
int |
t, |
|
|
double * |
output_line |
|
) |
| |
◆ ForwardTimeStep() [3/3]
void tesseract::FullyConnected::ForwardTimeStep |
( |
int |
t, |
|
|
double * |
output_line |
|
) |
| |
Definition at line 184 of file fullyconnected.cpp.
186 FuncInplace<GFunc>(
no_, output_line);
188 FuncInplace<FFunc>(
no_, output_line);
190 FuncInplace<ClipFFunc>(
no_, output_line);
192 FuncInplace<ClipGFunc>(
no_, output_line);
194 FuncInplace<Relu>(
no_, output_line);
198 ASSERT_HOST(
"Invalid fully-connected type!" ==
nullptr);
◆ InitWeights()
int tesseract::FullyConnected::InitWeights |
( |
float |
range, |
|
|
TRand * |
randomizer |
|
) |
| |
|
override virtual |
◆ OutputShape()
◆ RemapOutputs()
int tesseract::FullyConnected::RemapOutputs |
( |
int |
old_no, |
|
|
const std::vector< int > & |
code_map |
|
) |
| |
|
override virtual |
◆ Serialize()
bool tesseract::FullyConnected::Serialize |
( |
TFile * |
fp | ) |
const |
|
override virtual |
◆ SetEnableTraining()
void tesseract::FullyConnected::SetEnableTraining |
( |
TrainingState |
state | ) |
|
|
override virtual |
◆ SetupForward()
◆ spec()
STRING tesseract::FullyConnected::spec |
( |
| ) |
const |
|
inline override virtual |
◆ Update()
void tesseract::FullyConnected::Update |
( |
float |
learning_rate, |
|
|
float |
momentum, |
|
|
float |
adam_beta, |
|
|
int |
num_samples |
|
) |
| |
|
override virtual |
◆ acts_
◆ external_source_
◆ int_mode_
bool tesseract::FullyConnected::int_mode_ |
|
protected |
◆ source_t_
◆ weights_
The documentation for this class was generated from the following files: fullyconnected.h, fullyconnected.cpp
const TransposedArray * external_source_
TransposedArray source_t_
void CountAlternators(const WeightMatrix &other, double *same, double *changed) const
void FinishBackward(const TransposedArray &errors_t)
virtual void SetRandomizer(TRand *randomizer)
void add_str_int(const char *str, int number)
void DisplayForward(const NetworkIO &matrix)
void ZeroInvalidElements()
void SetupForward(const NetworkIO &input, const TransposedArray *input_transpose)
void SumOuterTransposed(const TransposedArray &u, const TransposedArray &v, bool parallel)
void WriteStrided(int t, const float *data)
void Update(double learning_rate, double momentum, double adam_beta, int num_samples)
void FuncMultiply(const NetworkIO &v_io, int t, double *product)
void Debug2D(const char *msg)
bool TestFlag(NetworkFlags flag) const
void BackwardTimeStep(const NetworkIO &fwd_deltas, int t, double *curr_errors, TransposedArray *errors_t, double *backprop)
void SoftmaxInPlace(int n, T *inout)
const char * c_str() const
void ForwardTimeStep(int t, double *output_line)
bool DeSerialize(bool training, TFile *fp)
void ResizeNoInit(int size1, int size2, int pad=0)
bool Serialize(bool training, TFile *fp) const
void CopyTimeStepFrom(int dest_t, const NetworkIO &src, int src_t)
int InitWeightsFloat(int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
int RemapOutputs(const std::vector< int > &code_map)
STRING spec() const override
void MatrixDotVector(const double *u, double *v) const
void Resize(const NetworkIO &src, int num_features)
const STRING & name() const
void init_to_size(int size, const T &t)
void VectorDotMatrix(const double *u, double *v) const
DLLSYM void tprintf(const char *format,...)
virtual bool Serialize(TFile *fp) const
void DisplayBackward(const NetworkIO &matrix)