tesseract  5.0.0-alpha-619-ge9db
intsimdmatrixavx2.cpp
Go to the documentation of this file.
1 // File: intsimdmatrixavx2.cpp
3 // Description: matrix-vector product for 8-bit data on avx2.
4 // Author: Ray Smith
5 // Created: Fri Aug 04 13:26:20 PST 2017
6 //
7 // (C) Copyright 2017, Google Inc.
8 // Licensed under the Apache License, Version 2.0 (the "License");
9 // you may not use this file except in compliance with the License.
10 // You may obtain a copy of the License at
11 // http://www.apache.org/licenses/LICENSE-2.0
12 // Unless required by applicable law or agreed to in writing, software
13 // distributed under the License is distributed on an "AS IS" BASIS,
14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 // See the License for the specific language governing permissions and
16 // limitations under the License.
18 
19 #if !defined(__AVX2__)
20 #error Implementation only for AVX2 capable architectures
21 #endif
22 
23 #include "intsimdmatrix.h"
24 
25 #include <immintrin.h>
26 #include <cstdint>
27 #include <algorithm>
28 #include <vector>
29 
30 namespace tesseract {
31 
// Layout constants shared by the AVX2 kernels below.

// Number of outputs held in each register. 8 x 32 bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of registers that we will use.
constexpr int kMaxOutputRegisters = 8;
// Number of inputs in the inputs register. 32 x 8 bit ints.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group.
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast, i.e. how many groups of
// kNumInputsPerGroup inputs fit in one inputs register.
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
42 
43 // Functions to compute part of a matrix.vector multiplication. The weights
44 // are in a very specific order (see above) in w, which is multiplied by
45 // u of length num_in, to produce output v after scaling the integer results
46 // by the corresponding member of scales.
47 // The amount of w and scales consumed is fixed and not available to the
48 // caller. The number of outputs written to v will be at most num_out.
49 
50 // Computes one set of 4x8 products of inputs and weights, adding to result.
51 // Horizontally adds 4 adjacent results, making 8x32-bit results.
52 // rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
53 // Note that wi must previously have been re-organized with blocks of 4x8
54 // weights in contiguous memory.
55 // ones is a register of 16x16-bit values all equal to 1.
56 // Note: wi is incremented by the amount of data read.
57 // weights and reps are scratch registers.
58 // This function must be inlined with references in order for the compiler to
59 // correctly use the registers declared in the caller.
60 static inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
61  const int8_t*& wi, __m256i& weights,
62  __m256i& reps, __m256i& result) {
63  // Load a 4x8 block of weights.
64  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(wi));
66  // Normalize the signs on rep_input, weights, so weights is always +ve.
67  reps = _mm256_sign_epi8(rep_input, weights);
68  weights = _mm256_sign_epi8(weights, weights);
69  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
70  // with adjacent pairs added.
71  weights = _mm256_maddubs_epi16(weights, reps);
72  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
73  // with adjacent pairs added. What we really want is a horizontal add of
74  // 16+16=32 bit result, but there is no such instruction, so multiply by
75  // 16-bit ones instead. It is probably faster than all the sign-extending,
76  // permuting and adding that would otherwise be required.
77  weights = _mm256_madd_epi16(weights, ones);
78  result = _mm256_add_epi32(result, weights);
79 }
80 
81 // Extracts and converts 8x32-bit results from result, adding the bias from wi
82 // and scaling by scales, before storing in *v. Note that wi, scales and v are
83 // expected to contain 8 consecutive elements or num_out if less.
84 static inline void ExtractResults(__m256i& result, __m256i& shift_id,
85  const int8_t*& wi, const double*& scales,
86  int num_out, double*& v) {
87  for (int out = 0; out < num_out; ++out) {
88 #ifndef _MSC_VER
89  auto res = _mm256_extract_epi32(result, 0);
90 #else
91  // Workaround MSVC's ICE
92  // _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
93  auto res = ((int32_t*)&result)[0];
94 #endif
95  *v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
96  // Rotate the results in int32_t units, so the next result is ready.
97  result = _mm256_permutevar8x32_epi32(result, shift_id);
98  }
99 }
100 
101 // Computes part of matrix.vector v = Wu. Computes N=64 results.
102 // The weights *must* be arranged so that consecutive reads from wi
103 // provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
104 // (kNumInputsPerGroup inputs))). After that there must be N consecutive
105 // bias weights, before continuing with any more weights.
106 // u must be padded out with zeros to
107 // kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
108 static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
109  const int8_t* u, int num_in, int num_out,
110  double* v) {
111  // Register containing 16-bit ones for horizontal add with 16->32 bit
112  // conversion.
113  __m256i ones =
114  _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
115  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
116  // Initialize all the results to 0.
117  __m256i result0 = _mm256_setzero_si256();
118  __m256i result1 = _mm256_setzero_si256();
119  __m256i result2 = _mm256_setzero_si256();
120  __m256i result3 = _mm256_setzero_si256();
121  __m256i result4 = _mm256_setzero_si256();
122  __m256i result5 = _mm256_setzero_si256();
123  __m256i result6 = _mm256_setzero_si256();
124  __m256i result7 = _mm256_setzero_si256();
125  // Iterate over the input (u), one registerful at a time.
126  for (int j = 0; j < num_in;) {
127  __m256i inputs =
128  _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
129  // Inputs are processed in groups of kNumInputsPerGroup, replicated
130  // kNumInputGroups times.
131  for (int ig = 0; ig < kNumInputGroups && j < num_in;
132  ++ig, j += kNumInputsPerGroup) {
133  // Replicate the low 32 bits (4 inputs) 8 times.
134  __m256i rep_input =
135  _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
136  // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
137  inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
138  __m256i weights, reps;
139  // Mul-add, with horizontal add of the 4 inputs to each of the results.
140  MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
141  MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
142  MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
143  MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
144  MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
145  MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
146  MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
147  MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
148  }
149  }
150  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
151  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
152  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
153  ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
154  ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
155  ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
156  ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
157  num_out -= kNumOutputsPerRegister * 7;
158  ExtractResults(result7, shift_id, wi, scales,
159  std::min(kNumOutputsPerRegister, num_out), v);
160 }
161 
162 // Computes part of matrix.vector v = Wu. Computes N=32 results.
163 // For details see PartialMatrixDotVector64 with N=32.
164 static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
165  const int8_t* u, int num_in, int num_out,
166  double* v) {
167  // Register containing 16-bit ones for horizontal add with 16->32 bit
168  // conversion.
169  __m256i ones =
170  _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
171  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
172  // Initialize all the results to 0.
173  __m256i result0 = _mm256_setzero_si256();
174  __m256i result1 = _mm256_setzero_si256();
175  __m256i result2 = _mm256_setzero_si256();
176  __m256i result3 = _mm256_setzero_si256();
177  // Iterate over the input (u), one registerful at a time.
178  for (int j = 0; j < num_in;) {
179  __m256i inputs =
180  _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
181  // Inputs are processed in groups of kNumInputsPerGroup, replicated
182  // kNumInputGroups times.
183  for (int ig = 0; ig < kNumInputGroups && j < num_in;
184  ++ig, j += kNumInputsPerGroup) {
185  // Replicate the low 32 bits (4 inputs) 8 times.
186  __m256i rep_input =
187  _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
188  // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
189  inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
190  __m256i weights, reps;
191  // Mul-add, with horizontal add of the 4 inputs to each of the results.
192  MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
193  MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
194  MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
195  MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
196  }
197  }
198  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
199  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
200  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
201  num_out -= kNumOutputsPerRegister * 3;
202  ExtractResults(result3, shift_id, wi, scales,
203  std::min(kNumOutputsPerRegister, num_out), v);
204 }
205 
206 // Computes part of matrix.vector v = Wu. Computes N=16 results.
207 // For details see PartialMatrixDotVector64 with N=16.
208 static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
209  const int8_t* u, int num_in, int num_out,
210  double* v) {
211  // Register containing 16-bit ones for horizontal add with 16->32 bit
212  // conversion.
213  __m256i ones =
214  _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
215  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
216  // Initialize all the results to 0.
217  __m256i result0 = _mm256_setzero_si256();
218  __m256i result1 = _mm256_setzero_si256();
219  // Iterate over the input (u), one registerful at a time.
220  for (int j = 0; j < num_in;) {
221  __m256i inputs =
222  _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
223  // Inputs are processed in groups of kNumInputsPerGroup, replicated
224  // kNumInputGroups times.
225  for (int ig = 0; ig < kNumInputGroups && j < num_in;
226  ++ig, j += kNumInputsPerGroup) {
227  // Replicate the low 32 bits (4 inputs) 8 times.
228  __m256i rep_input =
229  _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
230  // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
231  inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
232  __m256i weights, reps;
233  // Mul-add, with horizontal add of the 4 inputs to each of the results.
234  MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
235  MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
236  }
237  }
238  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
239  num_out -= kNumOutputsPerRegister;
240  ExtractResults(result1, shift_id, wi, scales,
241  std::min(kNumOutputsPerRegister, num_out), v);
242 }
243 
244 // Computes part of matrix.vector v = Wu. Computes N=8 results.
245 // For details see PartialMatrixDotVector64 with N=8.
246 static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
247  const int8_t* u, int num_in, int num_out,
248  double* v) {
249  // Register containing 16-bit ones for horizontal add with 16->32 bit
250  // conversion.
251  __m256i ones =
252  _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
253  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
254  // Initialize all the results to 0.
255  __m256i result0 = _mm256_setzero_si256();
256  // Iterate over the input (u), one registerful at a time.
257  for (int j = 0; j < num_in;) {
258  __m256i inputs =
259  _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
260  // Inputs are processed in groups of kNumInputsPerGroup, replicated
261  // kNumInputGroups times.
262  for (int ig = 0; ig < kNumInputGroups && j < num_in;
263  ++ig, j += kNumInputsPerGroup) {
264  // Replicate the low 32 bits (4 inputs) 8 times.
265  __m256i rep_input =
266  _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
267  // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
268  inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
269  __m256i weights, reps;
270  // Mul-add, with horizontal add of the 4 inputs to each of the results.
271  MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
272  }
273  }
274  ExtractResults(result0, shift_id, wi, scales, num_out, v);
275 }
276 
277 static void matrixDotVector(int dim1, int dim2, const int8_t* wi,
278  const double* scales, const int8_t* u, double* v) {
279  const int num_out = dim1;
280  const int num_in = dim2 - 1;
281  // Each call to a partial_func_ produces group_size outputs, except the
282  // last one, which can produce less.
283  const int rounded_num_in =
285  const int rounded_num_out =
287  int group_size = kNumOutputsPerRegister * kMaxOutputRegisters;
288  int output = 0;
289 
290  int w_step = (rounded_num_in + 1) * group_size;
291 
292  // Run with this group size, until it would produce too much output, then
293  // switch to a smaller size.
294  for (; output + group_size <= rounded_num_out; output += group_size) {
295  PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
296  wi += w_step;
297  scales += group_size;
298  v += group_size;
299  }
300  group_size /= 2;
301  w_step /= 2;
302 
303  for (; output + group_size <= rounded_num_out; output += group_size) {
304  PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
305  wi += w_step;
306  scales += group_size;
307  v += group_size;
308  }
309  group_size /= 2;
310  w_step /= 2;
311 
312  for (; output + group_size <= rounded_num_out; output += group_size) {
313  PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
314  wi += w_step;
315  scales += group_size;
316  v += group_size;
317  }
318  group_size /= 2;
319  w_step /= 2;
320 
321  for (; output + group_size <= rounded_num_out; output += group_size) {
322  PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
323  wi += w_step;
324  scales += group_size;
325  v += group_size;
326  }
327 }
328 
329 const IntSimdMatrix IntSimdMatrix::intSimdMatrixAVX2 = {
330  // Function.
331  matrixDotVector,
332  // Number of 32 bit outputs held in each register.
334  // Maximum number of registers that we will use to hold outputs.
336  // Number of 8 bit inputs in the inputs register.
338  // Number of inputs in each weight group.
340 };
341 
342 } // namespace tesseract.
tesseract::kNumInputsPerRegister
constexpr int kNumInputsPerRegister
Definition: intsimdmatrixavx2.cpp:37
tesseract::kNumOutputsPerRegister
constexpr int kNumOutputsPerRegister
Definition: intsimdmatrixavx2.cpp:33
tesseract::IntSimdMatrix::intSimdMatrixAVX2
static const IntSimdMatrix intSimdMatrixAVX2
Definition: intsimdmatrix.h:117
tesseract::kNumInputsPerGroup
constexpr int kNumInputsPerGroup
Definition: intsimdmatrixavx2.cpp:39
tesseract::IntSimdMatrix::Roundup
static int Roundup(int input, int factor)
Definition: intsimdmatrix.h:87
tesseract::kNumInputGroups
constexpr int kNumInputGroups
Definition: intsimdmatrixavx2.cpp:41
tesseract
Definition: baseapi.h:65
tesseract::kMaxOutputRegisters
constexpr int kMaxOutputRegisters
Definition: intsimdmatrixavx2.cpp:35
intsimdmatrix.h