tesseract  5.0.0-alpha-619-ge9db
tesseract::WeightMatrix Class Reference

#include <weightmatrix.h>

Public Member Functions

 WeightMatrix ()
 
int InitWeightsFloat (int no, int ni, bool use_adam, float weight_range, TRand *randomizer)
 
int RemapOutputs (const std::vector< int > &code_map)
 
void ConvertToInt ()
 
int RoundInputs (int size) const
 
bool is_int_mode () const
 
int NumOutputs () const
 
const double * GetWeights (int index) const
 
double GetDW (int i, int j) const
 
void InitBackward ()
 
bool Serialize (bool training, TFile *fp) const
 
bool DeSerialize (bool training, TFile *fp)
 
bool DeSerializeOld (bool training, TFile *fp)
 
void MatrixDotVector (const double *u, double *v) const
 
void MatrixDotVector (const int8_t *u, double *v) const
 
void MultiplyAccumulate (const double *v, double *inout)
 
void VectorDotMatrix (const double *u, double *v) const
 
void SumOuterTransposed (const TransposedArray &u, const TransposedArray &v, bool parallel)
 
void Update (double learning_rate, double momentum, double adam_beta, int num_samples)
 
void AddDeltas (const WeightMatrix &other)
 
void CountAlternators (const WeightMatrix &other, double *same, double *changed) const
 
void Debug2D (const char *msg)
 

Static Public Member Functions

static void FloatToDouble (const GENERIC_2D_ARRAY< float > &wf, GENERIC_2D_ARRAY< double > *wd)
 

Detailed Description

Definition at line 65 of file weightmatrix.h.

Constructor & Destructor Documentation

◆ WeightMatrix()

tesseract::WeightMatrix::WeightMatrix ( )
inline

Definition at line 67 of file weightmatrix.h.

67 : int_mode_(false), use_adam_(false) {}

Member Function Documentation

◆ AddDeltas()

void tesseract::WeightMatrix::AddDeltas ( const WeightMatrix &  other)

Definition at line 337 of file weightmatrix.cpp.

337  {
338  assert(dw_.dim1() == other.dw_.dim1());
339  assert(dw_.dim2() == other.dw_.dim2());
340  dw_ += other.dw_;
341 }

◆ ConvertToInt()

void tesseract::WeightMatrix::ConvertToInt ( )

Definition at line 125 of file weightmatrix.cpp.

125  {
126  wi_.ResizeNoInit(wf_.dim1(), wf_.dim2());
127  scales_.init_to_size(wi_.dim1(), 0.0);
128  int dim2 = wi_.dim2();
129  for (int t = 0; t < wi_.dim1(); ++t) {
130  double* f_line = wf_[t];
131  int8_t* i_line = wi_[t];
132  double max_abs = 0.0;
133  for (int f = 0; f < dim2; ++f) {
134  double abs_val = fabs(f_line[f]);
135  if (abs_val > max_abs) max_abs = abs_val;
136  }
137  double scale = max_abs / INT8_MAX;
138  scales_[t] = scale;
139  if (scale == 0.0) scale = 1.0;
140  for (int f = 0; f < dim2; ++f) {
141  i_line[f] = IntCastRounded(f_line[f] / scale);
142  }
143  }
144  wf_.Resize(1, 1, 0.0);
145  int_mode_ = true;
146  if (IntSimdMatrix::intSimdMatrix) {
147  IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
148  }
149 }

◆ CountAlternators()

void tesseract::WeightMatrix::CountAlternators ( const WeightMatrix &  other,
double *  same,
double *  changed 
) const

Definition at line 346 of file weightmatrix.cpp.

347  {
348  int num_outputs = updates_.dim1();
349  int num_inputs = updates_.dim2();
350  assert(num_outputs == other.updates_.dim1());
351  assert(num_inputs == other.updates_.dim2());
352  for (int i = 0; i < num_outputs; ++i) {
353  const double* this_i = updates_[i];
354  const double* other_i = other.updates_[i];
355  for (int j = 0; j < num_inputs; ++j) {
356  double product = this_i[j] * other_i[j];
357  if (product < 0.0)
358  *changed -= product;
359  else
360  *same += product;
361  }
362  }
363 }

◆ Debug2D()

void tesseract::WeightMatrix::Debug2D ( const char *  msg)

Definition at line 377 of file weightmatrix.cpp.

377  {
378  STATS histogram(0, kHistogramBuckets);
379  if (int_mode_) {
380  for (int i = 0; i < wi_.dim1(); ++i) {
381  for (int j = 0; j < wi_.dim2(); ++j) {
382  HistogramWeight(wi_[i][j] * scales_[i], &histogram);
383  }
384  }
385  } else {
386  for (int i = 0; i < wf_.dim1(); ++i) {
387  for (int j = 0; j < wf_.dim2(); ++j) {
388  HistogramWeight(wf_[i][j], &histogram);
389  }
390  }
391  }
392  tprintf("%s\n", msg);
393  histogram.print();
394 }

◆ DeSerialize()

bool tesseract::WeightMatrix::DeSerialize ( bool  training,
TFile *  fp 
)

Definition at line 191 of file weightmatrix.cpp.

191  {
192  uint8_t mode;
193  if (!fp->DeSerialize(&mode)) return false;
194  int_mode_ = (mode & kInt8Flag) != 0;
195  use_adam_ = (mode & kAdamFlag) != 0;
196  if ((mode & kDoubleFlag) == 0) return DeSerializeOld(training, fp);
197  if (int_mode_) {
198  if (!wi_.DeSerialize(fp)) return false;
199  if (!scales_.DeSerialize(fp)) return false;
200  if (IntSimdMatrix::intSimdMatrix) {
201  IntSimdMatrix::intSimdMatrix->Init(wi_, shaped_w_);
202  }
203  } else {
204  if (!wf_.DeSerialize(fp)) return false;
205  if (training) {
206  InitBackward();
207  if (!updates_.DeSerialize(fp)) return false;
208  if (use_adam_ && !dw_sq_sum_.DeSerialize(fp)) return false;
209  }
210  }
211  return true;
212 }

◆ DeSerializeOld()

bool tesseract::WeightMatrix::DeSerializeOld ( bool  training,
TFile *  fp 
)

Definition at line 216 of file weightmatrix.cpp.

216  {
217  GENERIC_2D_ARRAY<float> float_array;
218  if (int_mode_) {
219  if (!wi_.DeSerialize(fp)) return false;
220  GenericVector<float> old_scales;
221  if (!old_scales.DeSerialize(fp)) return false;
222  scales_.resize_no_init(old_scales.size());
223  for (int i = 0; i < old_scales.size(); ++i) scales_[i] = old_scales[i];
224  } else {
225  if (!float_array.DeSerialize(fp)) return false;
226  FloatToDouble(float_array, &wf_);
227  }
228  if (training) {
229  InitBackward();
230  if (!float_array.DeSerialize(fp)) return false;
231  FloatToDouble(float_array, &updates_);
232  // Errs was only used in int training, which is now dead.
233  if (!float_array.DeSerialize(fp)) return false;
234  }
235  return true;
236 }

◆ FloatToDouble()

void tesseract::WeightMatrix::FloatToDouble ( const GENERIC_2D_ARRAY< float > &  wf,
GENERIC_2D_ARRAY< double > *  wd 
)
static

Definition at line 399 of file weightmatrix.cpp.

400  {
401  int dim1 = wf.dim1();
402  int dim2 = wf.dim2();
403  wd->ResizeNoInit(dim1, dim2);
404  for (int i = 0; i < dim1; ++i) {
405  const float* wfi = wf[i];
406  double* wdi = (*wd)[i];
407  for (int j = 0; j < dim2; ++j) wdi[j] = static_cast<double>(wfi[j]);
408  }
409 }

◆ GetDW()

double tesseract::WeightMatrix::GetDW ( int  i,
int  j 
) const
inline

Definition at line 105 of file weightmatrix.h.

105 { return dw_(i, j); }

◆ GetWeights()

const double* tesseract::WeightMatrix::GetWeights ( int  index) const
inline

Definition at line 103 of file weightmatrix.h.

103 { return wf_[index]; }

◆ InitBackward()

void tesseract::WeightMatrix::InitBackward ( )

Definition at line 153 of file weightmatrix.cpp.

153  {
154  int no = int_mode_ ? wi_.dim1() : wf_.dim1();
155  int ni = int_mode_ ? wi_.dim2() : wf_.dim2();
156  dw_.Resize(no, ni, 0.0);
157  updates_.Resize(no, ni, 0.0);
158  wf_t_.Transpose(wf_);
159  if (use_adam_) dw_sq_sum_.Resize(no, ni, 0.0);
160 }

◆ InitWeightsFloat()

int tesseract::WeightMatrix::InitWeightsFloat ( int  no,
int  ni,
bool  use_adam,
float  weight_range,
TRand *  randomizer 
)

Definition at line 76 of file weightmatrix.cpp.

77  {
78  int_mode_ = false;
79  wf_.Resize(no, ni, 0.0);
80  if (randomizer != nullptr) {
81  for (int i = 0; i < no; ++i) {
82  for (int j = 0; j < ni; ++j) {
83  wf_[i][j] = randomizer->SignedRand(weight_range);
84  }
85  }
86  }
87  use_adam_ = use_adam;
88  InitBackward();
89  return ni * no;
90 }

◆ is_int_mode()

bool tesseract::WeightMatrix::is_int_mode ( ) const
inline

Definition at line 98 of file weightmatrix.h.

98  {
99  return int_mode_;
100  }

◆ MatrixDotVector() [1/2]

void tesseract::WeightMatrix::MatrixDotVector ( const double *  u,
double *  v 
) const

Definition at line 243 of file weightmatrix.cpp.

243  {
244  assert(!int_mode_);
245  MatrixDotVectorInternal(wf_, true, false, u, v);
246 }

◆ MatrixDotVector() [2/2]

void tesseract::WeightMatrix::MatrixDotVector ( const int8_t *  u,
double *  v 
) const

Definition at line 248 of file weightmatrix.cpp.

248  {
249  assert(int_mode_);
250  if (IntSimdMatrix::intSimdMatrix) {
251  IntSimdMatrix::intSimdMatrix->matrixDotVectorFunction(
252  wi_.dim1(), wi_.dim2(), &shaped_w_[0], &scales_[0], u, v);
253  } else {
254  IntSimdMatrix::MatrixDotVector(wi_, scales_, u, v);
255  }
256 }

◆ MultiplyAccumulate()

void tesseract::WeightMatrix::MultiplyAccumulate ( const double *  v,
double *  inout 
)

Definition at line 260 of file weightmatrix.cpp.

260  {
261  assert(!int_mode_);
262  assert(wf_.dim1() == 1);
263  int n = wf_.dim2();
264  const double* u = wf_[0];
265  for (int i = 0; i < n; ++i) {
266  inout[i] += u[i] * v[i];
267  }
268 }

◆ NumOutputs()

int tesseract::WeightMatrix::NumOutputs ( ) const
inline

Definition at line 101 of file weightmatrix.h.

101 { return int_mode_ ? wi_.dim1() : wf_.dim1(); }

◆ RemapOutputs()

int tesseract::WeightMatrix::RemapOutputs ( const std::vector< int > &  code_map)

Definition at line 97 of file weightmatrix.cpp.

97  {
98  GENERIC_2D_ARRAY<double> old_wf(wf_);
99  int old_no = wf_.dim1();
100  int new_no = code_map.size();
101  int ni = wf_.dim2();
102  std::vector<double> means(ni, 0.0);
103  for (int c = 0; c < old_no; ++c) {
104  const double* weights = wf_[c];
105  for (int i = 0; i < ni; ++i) means[i] += weights[i];
106  }
107  for (double& mean : means) mean /= old_no;
108  wf_.ResizeNoInit(new_no, ni);
109  InitBackward();
110  for (int dest = 0; dest < new_no; ++dest) {
111  int src = code_map[dest];
112  const double* src_data = src >= 0 ? old_wf[src] : means.data();
113  memcpy(wf_[dest], src_data, ni * sizeof(*src_data));
114  }
115  return ni * new_no;
116 }

◆ RoundInputs()

int tesseract::WeightMatrix::RoundInputs ( int  size) const
inline

Definition at line 92 of file weightmatrix.h.

92  {
93  if (!int_mode_ || !IntSimdMatrix::intSimdMatrix) return size;
94  return IntSimdMatrix::intSimdMatrix->RoundInputs(size);
95  }

◆ Serialize()

bool tesseract::WeightMatrix::Serialize ( bool  training,
TFile *  fp 
) const

Definition at line 172 of file weightmatrix.cpp.

172  {
173  // For backward compatibility, add kDoubleFlag to mode to indicate the doubles
174  // format, without errs, so we can detect and read old format weight matrices.
175  uint8_t mode =
176  (int_mode_ ? kInt8Flag : 0) | (use_adam_ ? kAdamFlag : 0) | kDoubleFlag;
177  if (!fp->Serialize(&mode)) return false;
178  if (int_mode_) {
179  if (!wi_.Serialize(fp)) return false;
180  if (!scales_.Serialize(fp)) return false;
181  } else {
182  if (!wf_.Serialize(fp)) return false;
183  if (training && !updates_.Serialize(fp)) return false;
184  if (training && use_adam_ && !dw_sq_sum_.Serialize(fp)) return false;
185  }
186  return true;
187 }

◆ SumOuterTransposed()

void tesseract::WeightMatrix::SumOuterTransposed ( const TransposedArray &  u,
const TransposedArray &  v,
bool  parallel 
)

Definition at line 284 of file weightmatrix.cpp.

286  {
287  assert(!int_mode_);
288  int num_outputs = dw_.dim1();
289  assert(u.dim1() == num_outputs);
290  assert(u.dim2() == v.dim2());
291  int num_inputs = dw_.dim2() - 1;
292  int num_samples = u.dim2();
293  // v is missing the last element in dim1.
294  assert(v.dim1() == num_inputs);
295 #ifdef _OPENMP
296 #pragma omp parallel for num_threads(4) if (in_parallel)
297 #endif
298  for (int i = 0; i < num_outputs; ++i) {
299  double* dwi = dw_[i];
300  const double* ui = u[i];
301  for (int j = 0; j < num_inputs; ++j) {
302  dwi[j] = DotProduct(ui, v[j], num_samples);
303  }
304  // The last element of v is missing, presumed 1.0f.
305  double total = 0.0;
306  for (int k = 0; k < num_samples; ++k) total += ui[k];
307  dwi[num_inputs] = total;
308  }
309 }

◆ Update()

void tesseract::WeightMatrix::Update ( double  learning_rate,
double  momentum,
double  adam_beta,
int  num_samples 
)

Definition at line 314 of file weightmatrix.cpp.

315  {
316  assert(!int_mode_);
317  if (use_adam_ && num_samples > 0 && num_samples < kAdamCorrectionIterations) {
318  learning_rate *= sqrt(1.0 - pow(adam_beta, num_samples));
319  learning_rate /= 1.0 - pow(momentum, num_samples);
320  }
321  if (use_adam_ && num_samples > 0 && momentum > 0.0) {
322  dw_sq_sum_.SumSquares(dw_, adam_beta);
323  dw_ *= learning_rate * (1.0 - momentum);
324  updates_ *= momentum;
325  updates_ += dw_;
326  wf_.AdamUpdate(updates_, dw_sq_sum_, learning_rate * kAdamEpsilon);
327  } else {
328  dw_ *= learning_rate;
329  updates_ += dw_;
330  if (momentum > 0.0) wf_ += updates_;
331  if (momentum >= 0.0) updates_ *= momentum;
332  }
333  wf_t_.Transpose(wf_);
334 }

◆ VectorDotMatrix()

void tesseract::WeightMatrix::VectorDotMatrix ( const double *  u,
double *  v 
) const

Definition at line 274 of file weightmatrix.cpp.

274  {
275  assert(!int_mode_);
276  MatrixDotVectorInternal(wf_t_, false, true, u, v);
277 }

The documentation for this class was generated from the following files:
weightmatrix.h
weightmatrix.cpp
tesseract::kAdamEpsilon
const double kAdamEpsilon
Definition: weightmatrix.cpp:37
tesseract::kHistogramBuckets
const int kHistogramBuckets
Definition: weightmatrix.cpp:367
tesseract::IntSimdMatrix::Init
void Init(const GENERIC_2D_ARRAY< int8_t > &w, std::vector< int8_t > &shaped_w) const
Definition: intsimdmatrix.cpp:29
GENERIC_2D_ARRAY::DeSerialize
bool DeSerialize(bool swap, FILE *fp)
Definition: matrix.h:160
GenericVector::Serialize
bool Serialize(FILE *fp) const
Definition: genericvector.h:929
IntCastRounded
int IntCastRounded(double x)
Definition: helpers.h:173
tesseract::kAdamCorrectionIterations
const int kAdamCorrectionIterations
Definition: weightmatrix.cpp:35
tesseract::IntSimdMatrix::matrixDotVectorFunction
MatrixDotVectorFunction matrixDotVectorFunction
Definition: intsimdmatrix.h:103
GENERIC_2D_ARRAY< float >
tesseract::WeightMatrix::InitBackward
void InitBackward()
Definition: weightmatrix.cpp:153
GENERIC_2D_ARRAY::dim2
int dim2() const
Definition: matrix.h:206
tesseract::TransposedArray::Transpose
void Transpose(const GENERIC_2D_ARRAY< double > &input)
Definition: weightmatrix.cpp:62
GenericVector::DeSerialize
bool DeSerialize(bool swap, FILE *fp)
Definition: genericvector.h:954
tesseract::kAdamFlag
const int kAdamFlag
Definition: weightmatrix.cpp:165
GENERIC_2D_ARRAY::ResizeNoInit
void ResizeNoInit(int size1, int size2, int pad=0)
Definition: matrix.h:90
tesseract::DotProduct
DotProductFunction DotProduct
Definition: simddetect.cpp:50
GenericVector::resize_no_init
void resize_no_init(int size)
Definition: genericvector.h:65
tesseract::IntSimdMatrix::intSimdMatrix
static const IntSimdMatrix * intSimdMatrix
Definition: intsimdmatrix.h:116
tesseract::IntSimdMatrix::RoundInputs
int RoundInputs(int size) const
Definition: intsimdmatrix.h:69
tesseract::WeightMatrix::DeSerializeOld
bool DeSerializeOld(bool training, TFile *fp)
Definition: weightmatrix.cpp:216
STATS
Definition: statistc.h:30
GenericVector< float >
GENERIC_2D_ARRAY::Serialize
bool Serialize(FILE *fp) const
Definition: matrix.h:143
tesseract::IntSimdMatrix::MatrixDotVector
static void MatrixDotVector(const GENERIC_2D_ARRAY< int8_t > &w, const GenericVector< double > &scales, const int8_t *u, double *v)
Definition: intsimdmatrix.cpp:79
GENERIC_2D_ARRAY::SumSquares
void SumSquares(const GENERIC_2D_ARRAY< T > &src, const T &decay_factor)
Definition: matrix.h:367
tesstrain_utils.dest
dest
Definition: tesstrain_utils.py:139
GENERIC_2D_ARRAY::Resize
void Resize(int size1, int size2, const T &empty)
Definition: matrix.h:104
GenericVector::init_to_size
void init_to_size(int size, const T &t)
Definition: genericvector.h:706
tprintf
DLLSYM void tprintf(const char *format,...)
Definition: tprintf.cpp:34
GENERIC_2D_ARRAY::AdamUpdate
void AdamUpdate(const GENERIC_2D_ARRAY< T > &sum, const GENERIC_2D_ARRAY< T > &sqsum, const T &epsilon)
Definition: matrix.h:378
tesseract::WeightMatrix::FloatToDouble
static void FloatToDouble(const GENERIC_2D_ARRAY< float > &wf, GENERIC_2D_ARRAY< double > *wd)
Definition: weightmatrix.cpp:399
GenericVector::size
int size() const
Definition: genericvector.h:71
tesseract::kDoubleFlag
const int kDoubleFlag
Definition: weightmatrix.cpp:169
GENERIC_2D_ARRAY::dim1
int dim1() const
Definition: matrix.h:205
tesseract::kInt8Flag
const int kInt8Flag
Definition: weightmatrix.cpp:163