// tesseract 4.0.0-1-g2a2b — intsimdmatrixavx2.cpp
// (Doxygen page header residue from the generated documentation.)
// File:        intsimdmatrixavx2.cpp
// Description: matrix-vector product for 8-bit data on avx2.
// Author:      Ray Smith
// Created:     Fri Aug 04 13:26:20 PST 2017
//
// (C) Copyright 2017, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
18 
#include "intsimdmatrixavx2.h"

#ifdef __AVX2__
#include <immintrin.h>

#include <algorithm>
#include <cstdint>
#include <vector>
26 
namespace tesseract {

// Number of outputs held in each AVX2 register: 8 x 32-bit ints.
constexpr int kNumOutputsPerRegister = 8;
// Maximum number of result registers used at once (8 * 8 = 64 outputs).
constexpr int kMaxOutputRegisters = 8;
// Number of 8-bit inputs held in one 256-bit register.
constexpr int kNumInputsPerRegister = 32;
// Number of inputs in each weight group (horizontally summed per output).
constexpr int kNumInputsPerGroup = 4;
// Number of groups of inputs to be broadcast (32 / 4 = 8).
constexpr int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
39 
// Computes one set of 4x8 products of inputs and weights, adding to result.
// Horizontally adds 4 adjacent results, making 8x32-bit results.
// rep_input is assumed to be an 8x replicated set of 4x8-bit signed integers.
// Note that wi must previously have been re-organized with blocks of 4x8
// weights in contiguous memory.
// ones is a register of 16x16-bit values all equal to 1.
// Note: wi is incremented by the amount of data read (32 bytes).
// weights and reps are caller-provided scratch registers.
// This function must be inlined with references in order for the compiler to
// correctly use the registers declared in the caller.
inline void MultiplyGroup(const __m256i& rep_input, const __m256i& ones,
                          const int8_t*& wi, __m256i& weights, __m256i& reps,
                          __m256i& result) {
  // Load a 4x8 block of weights.
  weights = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(wi));
  wi += kNumInputsPerRegister;
  // Normalize the signs on rep_input, weights, so weights is always +ve.
  // This is required because maddubs below treats its first operand as
  // unsigned and its second as signed.
  reps = _mm256_sign_epi8(rep_input, weights);
  weights = _mm256_sign_epi8(weights, weights);
  // Multiply 32x8-bit reps by 32x8-bit weights to make 16x16-bit results,
  // with adjacent pairs added.
  weights = _mm256_maddubs_epi16(weights, reps);
  // Multiply 16x16-bit result by 16x16-bit ones to make 8x32-bit results,
  // with adjacent pairs added. What we really want is a horizontal add of
  // 16+16=32 bit result, but there is no such instruction, so multiply by
  // 16-bit ones instead. It is probably faster than all the sign-extending,
  // permuting and adding that would otherwise be required.
  weights = _mm256_madd_epi16(weights, ones);
  result = _mm256_add_epi32(result, weights);
}
70 
71 // Extracts and converts 8x32-bit results from result, adding the bias from wi
72 // and scaling by scales, before storing in *v. Note that wi, scales and v are
73 // expected to contain 8 consecutive elements or num_out if less.
74 inline void ExtractResults(__m256i& result, __m256i& shift_id,
75  const int8_t*& wi, const double*& scales,
76  int num_out, double*& v) {
77  for (int out = 0; out < num_out; ++out) {
78  int32_t res =
79 #ifndef _MSC_VER
80  _mm256_extract_epi32(result, 0)
81 #else
82  // Workaround MSVC's ICE
83  // _mm256_extract_epi32(X, Y) == ((int32_t*)&X)[Y]
84  ((int32_t*)&result)[0]
85 #endif
86  ;
87  *v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
88  // Rotate the results in int32_t units, so the next result is ready.
89  result = _mm256_permutevar8x32_epi32(result, shift_id);
90  }
91 }
92 
// Computes part of matrix.vector v = Wu. Computes N=64 results.
// The weights *must* be arranged so that consecutive reads from wi
// provides (num_in/kNumInputsPerGroup groups of (N output dim groups of
// (kNumInputsPerGroup inputs))). After that there must be N consecutive
// bias weights, before continuing with any more weights.
// u must be padded out with zeros to
// kNumInputsPerGroup*ceil(num_in/kNumInputsPerGroup) elements.
static void PartialMatrixDotVector64(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  // Permutation that rotates the 8x32-bit lanes down by one position.
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0. Eight registers hold the 64 outputs;
  // they are kept in named locals (not an array) so MultiplyGroup inlines
  // onto registers declared here.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  __m256i result4 = _mm256_setzero_si256();
  __m256i result5 = _mm256_setzero_si256();
  __m256i result6 = _mm256_setzero_si256();
  __m256i result7 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
    }
  }
  // Extract 8 outputs per register; wi now points at the bias weights,
  // which ExtractResults consumes along with scales and v.
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister * 7;
  // The last register may hold fewer than 8 valid outputs.
  ExtractResults(result7, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
}
153 
// Computes part of matrix.vector v = Wu. Computes N=32 results.
// For details see PartialMatrixDotVector64 with N=32.
static void PartialMatrixDotVector32(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  // Permutation that rotates the 8x32-bit lanes down by one position.
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0. Four registers hold the 32 outputs.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  __m256i result2 = _mm256_setzero_si256();
  __m256i result3 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
    }
  }
  // Extract 8 outputs per register; wi now points at the bias weights.
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
  ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister * 3;
  // The last register may hold fewer than 8 valid outputs.
  ExtractResults(result3, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
}
197 
// Computes part of matrix.vector v = Wu. Computes N=16 results.
// For details see PartialMatrixDotVector64 with N=16.
static void PartialMatrixDotVector16(const int8_t* wi, const double* scales,
                                     const int8_t* u, int num_in, int num_out,
                                     double* v) {
  // Register containing 16-bit ones for horizontal add with 16->32 bit
  // conversion.
  __m256i ones =
      _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
  // Permutation that rotates the 8x32-bit lanes down by one position.
  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
  // Initialize all the results to 0. Two registers hold the 16 outputs.
  __m256i result0 = _mm256_setzero_si256();
  __m256i result1 = _mm256_setzero_si256();
  // Iterate over the input (u), one registerful at a time.
  for (int j = 0; j < num_in;) {
    __m256i inputs =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
    // Inputs are processed in groups of kNumInputsPerGroup, replicated
    // kNumInputGroups times.
    for (int ig = 0; ig < kNumInputGroups && j < num_in;
         ++ig, j += kNumInputsPerGroup) {
      // Replicate the low 32 bits (4 inputs) 8 times.
      __m256i rep_input =
          _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
      // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
      inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
      __m256i weights, reps;
      // Mul-add, with horizontal add of the 4 inputs to each of the results.
      MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
      MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
    }
  }
  // Extract 8 outputs per register; wi now points at the bias weights.
  ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
  num_out -= kNumOutputsPerRegister;
  // The last register may hold fewer than 8 valid outputs.
  ExtractResults(result1, shift_id, wi, scales,
                 std::min(kNumOutputsPerRegister, num_out), v);
}
235 
236 // Computes part of matrix.vector v = Wu. Computes N=8 results.
237 // For details see PartialMatrixDotVector64 with N=8.
238 static void PartialMatrixDotVector8(const int8_t* wi, const double* scales,
239  const int8_t* u, int num_in, int num_out,
240  double* v) {
241  // Register containing 16-bit ones for horizontal add with 16->32 bit
242  // conversion.
243  __m256i ones =
244  _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
245  __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
246  // Initialize all the results to 0.
247  __m256i result0 = _mm256_setzero_si256();
248  // Iterate over the input (u), one registerful at a time.
249  for (int j = 0; j < num_in;) {
250  __m256i inputs =
251  _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
252  // Inputs are processed in groups of kNumInputsPerGroup, replicated
253  // kNumInputGroups times.
254  for (int ig = 0; ig < kNumInputGroups && j < num_in;
255  ++ig, j += kNumInputsPerGroup) {
256  // Replicate the low 32 bits (4 inputs) 8 times.
257  __m256i rep_input =
258  _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
259  // Rotate the inputs in groups of 4, so the next 4 inputs are ready.
260  inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
261  __m256i weights, reps;
262  // Mul-add, with horizontal add of the 4 inputs to each of the results.
263  MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
264  }
265  }
266  ExtractResults(result0, shift_id, wi, scales, num_out, v);
267 }
268 #else
269 namespace tesseract {
270 #endif // __AVX2__
271 
#ifdef __AVX2__
  // NOTE(review): the line opening this definition (internal line 272) is
  // missing from this extract; judging by the filename and the members
  // assigned below, this is presumably the body of the IntSimdMatrixAVX2
  // constructor — confirm against intsimdmatrixavx2.h.
  // Publish the AVX2 kernel geometry for the base class.
  num_outputs_per_register_ = kNumOutputsPerRegister;
  max_output_registers_ = kMaxOutputRegisters;
  num_inputs_per_register_ = kNumInputsPerRegister;
  num_inputs_per_group_ = kNumInputsPerGroup;
  num_input_groups_ = kNumInputGroups;
  // Partial dot-product kernels, ordered widest (64 outputs) to narrowest (8).
  partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32,
                    PartialMatrixDotVector16, PartialMatrixDotVector8};
#endif  // __AVX2__
}
283 
284 } // namespace tesseract.
// std::vector<PartialFunc> partial_funcs_ — Doxygen member-link residue from
// the generated documentation page; the member itself is declared in the header.