19 #if !defined(__AVX2__)
20 #error Implementation only for AVX2 capable architectures
25 #include <immintrin.h>
// Multiplies one 32-byte block of int8 weights (read from wi) by rep_input
// and adds the products into `result`. Each of the eight 32-bit lanes of
// `result` gains the sum of four int8 x int8 products.
//
// rep_input: 32 signed int8 inputs. Callers in this file pass a value built
//            with _mm256_broadcastd_epi32, i.e. the same 4 bytes in every
//            32-bit lane.
// ones:      callers in this file build it with _mm256_set_epi16(1, ...), so
//            it holds sixteen 16-bit ones.
// wi:        weight pointer; 32 bytes are read with an unaligned load.
// weights, reps: scratch registers; both are overwritten.
// result:    accumulator, updated with _mm256_add_epi32.
//
// NOTE(review): wi is taken by reference, but no increment of it is visible
// in this block. If it is not advanced (presumably by 32 bytes) per call,
// every call reads the same weights. Confirm.
60 static inline void MultiplyGroup(
const __m256i& rep_input,
const __m256i& ones,
61 const int8_t*& wi, __m256i& weights,
62 __m256i& reps, __m256i& result) {
// Unaligned load of 32 int8 weights.
64 weights = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(wi));
// Move each weight's sign onto the matching input (zeroing it where the
// weight is 0) and replace weights by their magnitudes. This is needed
// because _mm256_maddubs_epi16 treats its first operand as unsigned.
67 reps = _mm256_sign_epi8(rep_input, weights);
68 weights = _mm256_sign_epi8(weights, weights);
// unsigned |w| * signed x, adjacent pairs summed into sixteen 16-bit values
// (the intrinsic saturates to int16).
71 weights = _mm256_maddubs_epi16(weights, reps);
// Multiplying by 16-bit ones and adding adjacent pairs acts as a horizontal
// 16+16 -> 32-bit add, giving eight 32-bit sums of four products each.
77 weights = _mm256_madd_epi16(weights, ones);
78 result = _mm256_add_epi32(result, weights);
// Emits num_out doubles to v. Each one is taken from lane 0 of `result`,
// and `result` is then rotated with _mm256_permutevar8x32_epi32(shift_id).
//
// Each output is computed as:
//   *v = (res / INT8_MAX + *wi) * *scales
// where res is the int32 in lane 0. The int8 read from *wi is presumably a
// per-output bias stored in the weight stream (confirm).
//
// wi, scales and v each advance by one element per output. They are
// reference parameters, so the caller sees the advanced pointers.
// `result` is modified in place by the rotation.
//
// shift_id: callers in this file pass _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1).
// That permutation moves lane i+1 into lane i and lane 0 into lane 7.
//
// NOTE(review): `result` has only 8 lanes. With num_out > 8 the rotation
// wraps and earlier lanes are emitted again. Presumably callers pass at
// most 8; confirm.
//
// NOTE(review): the two `res` declarations below are in the same scope. They
// only compile if they sit in mutually exclusive preprocessor branches, which
// are not visible here. Confirm.
84 static inline void ExtractResults(__m256i& result, __m256i& shift_id,
85 const int8_t*& wi,
const double*& scales,
86 int num_out,
double*& v) {
87 for (
int out = 0; out < num_out; ++out) {
// Read the int32 in lane 0 of result.
89 auto res = _mm256_extract_epi32(result, 0);
// Alternative read of lane 0 through a pointer cast of the vector.
93 auto res = ((int32_t*)&result)[0];
95 *v++ = (static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
// Rotate so the next lane's sum is in lane 0 for the next iteration.
97 result = _mm256_permutevar8x32_epi32(result, shift_id);
// Computes int8 dot products of the input vector u against the weight
// stream wi. It uses eight int32 accumulators (result0..result7) of 8 lanes
// each, i.e. up to 64 sums, and converts the sums to doubles via
// ExtractResults.
//
// wi:      weights. Each loop step calls MultiplyGroup once per accumulator,
//          each call reading a 32-byte block.
// scales:  per-output multipliers, consumed by ExtractResults.
// u:       int8 inputs. 32 bytes are loaded from u + j per loop step.
// num_in:  loop bound for j; matrixDotVector passes rounded_num_in.
// num_out: matrixDotVector passes `num_out - output`.
//
// NOTE(review): the for-header does not advance j, and the declarations of
// `ones`, `inputs` and `rep_input` are not visible here. Confirm where j is
// advanced and that rep_input holds the broadcast computed in the loop.
108 static void PartialMatrixDotVector64(
const int8_t* wi,
const double* scales,
109 const int8_t* u,
int num_in,
int num_out,
114 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Lane permutation {1, 2, 3, 4, 5, 6, 7, 0}: rotates 32-bit lanes down by one.
115 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Eight zero-initialized accumulators of eight int32 lanes each.
117 __m256i result0 = _mm256_setzero_si256();
118 __m256i result1 = _mm256_setzero_si256();
119 __m256i result2 = _mm256_setzero_si256();
120 __m256i result3 = _mm256_setzero_si256();
121 __m256i result4 = _mm256_setzero_si256();
122 __m256i result5 = _mm256_setzero_si256();
123 __m256i result6 = _mm256_setzero_si256();
124 __m256i result7 = _mm256_setzero_si256();
// Per step:
// - load 32 input bytes;
// - broadcast the low 4 bytes to every 32-bit lane;
// - rotate the inputs by one lane so the next 4 bytes move to the bottom;
// - multiply-accumulate the broadcast against one weight block per
//   accumulator.
126 for (
int j = 0; j < num_in;) {
128 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
135 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
137 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
138 __m256i weights, reps;
140 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
141 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
142 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
143 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
144 MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
145 MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
146 MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
147 MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
// Convert the last accumulator's sums to scaled doubles.
158 ExtractResults(result7, shift_id, wi, scales,
// Same scheme as PartialMatrixDotVector64, but with four 8-lane int32
// accumulators (result0..result3, i.e. up to 32 sums).
//
// Each loop step:
// - broadcasts 4 input bytes from u;
// - multiply-accumulates them against four 32-byte weight blocks from wi;
// - rotates the loaded inputs by one 32-bit lane.
// Sums are converted to doubles via ExtractResults.
//
// NOTE(review): the for-header does not advance j, and the declarations of
// `ones`, `inputs` and `rep_input` are not visible here. Confirm.
164 static void PartialMatrixDotVector32(
const int8_t* wi,
const double* scales,
165 const int8_t* u,
int num_in,
int num_out,
170 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Lane permutation {1, 2, 3, 4, 5, 6, 7, 0}: rotates 32-bit lanes down by one.
171 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Four zero-initialized accumulators of eight int32 lanes each.
173 __m256i result0 = _mm256_setzero_si256();
174 __m256i result1 = _mm256_setzero_si256();
175 __m256i result2 = _mm256_setzero_si256();
176 __m256i result3 = _mm256_setzero_si256();
// Load inputs, broadcast the low 4 bytes, rotate the rest down one lane,
// then multiply-accumulate into each accumulator.
178 for (
int j = 0; j < num_in;) {
180 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
187 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
189 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
190 __m256i weights, reps;
192 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
193 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
194 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
195 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
// Convert the last accumulator's sums to scaled doubles.
202 ExtractResults(result3, shift_id, wi, scales,
// Same scheme as PartialMatrixDotVector64, but with two 8-lane int32
// accumulators (result0, result1, i.e. up to 16 sums).
//
// Each loop step:
// - broadcasts 4 input bytes from u;
// - multiply-accumulates them against two 32-byte weight blocks from wi.
// Sums are converted to doubles via ExtractResults.
//
// NOTE(review): the for-header does not advance j, and the declarations of
// `ones`, `inputs` and `rep_input` are not visible here. Confirm.
208 static void PartialMatrixDotVector16(
const int8_t* wi,
const double* scales,
209 const int8_t* u,
int num_in,
int num_out,
214 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Lane permutation {1, 2, 3, 4, 5, 6, 7, 0}: rotates 32-bit lanes down by one.
215 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// Two zero-initialized accumulators of eight int32 lanes each.
217 __m256i result0 = _mm256_setzero_si256();
218 __m256i result1 = _mm256_setzero_si256();
// Load inputs, broadcast the low 4 bytes, rotate the rest down one lane,
// then multiply-accumulate into each accumulator.
220 for (
int j = 0; j < num_in;) {
222 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
229 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
231 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
232 __m256i weights, reps;
234 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
235 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
// Convert the last accumulator's sums to scaled doubles.
240 ExtractResults(result1, shift_id, wi, scales,
// Same scheme as PartialMatrixDotVector64, but with a single 8-lane int32
// accumulator (result0, i.e. up to 8 sums).
//
// Each loop step:
// - broadcasts 4 input bytes from u;
// - multiply-accumulates them against one 32-byte weight block from wi.
// The sums are converted to doubles via ExtractResults.
//
// NOTE(review): num_out is forwarded unchanged to ExtractResults, which
// reads only 8 lanes. matrixDotVector passes `num_out - output`, so this
// presumably relies on that being <= 8 here. Confirm.
//
// NOTE(review): `v`, `ones`, `inputs` and `rep_input` are used, but their
// declarations are not visible here, and the for-header does not advance j.
// Confirm.
246 static void PartialMatrixDotVector8(
const int8_t* wi,
const double* scales,
247 const int8_t* u,
int num_in,
int num_out,
252 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Lane permutation {1, 2, 3, 4, 5, 6, 7, 0}: rotates 32-bit lanes down by one.
253 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// One zero-initialized accumulator of eight int32 lanes.
255 __m256i result0 = _mm256_setzero_si256();
// Load inputs, broadcast the low 4 bytes, rotate the rest down one lane,
// then multiply-accumulate into the accumulator.
257 for (
int j = 0; j < num_in;) {
259 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
266 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
268 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
269 __m256i weights, reps;
271 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
// Emit num_out scaled doubles from the accumulator's lanes.
274 ExtractResults(result0, shift_id, wi, scales, num_out, v);
277 static void matrixDotVector(
int dim1,
int dim2,
const int8_t* wi,
278 const double* scales,
const int8_t* u,
double* v) {
279 const int num_out = dim1;
280 const int num_in = dim2 - 1;
283 const int rounded_num_in =
285 const int rounded_num_out =
290 int w_step = (rounded_num_in + 1) * group_size;
294 for (; output + group_size <= rounded_num_out; output += group_size) {
295 PartialMatrixDotVector64(wi, scales, u, rounded_num_in, num_out - output, v);
297 scales += group_size;
303 for (; output + group_size <= rounded_num_out; output += group_size) {
304 PartialMatrixDotVector32(wi, scales, u, rounded_num_in, num_out - output, v);
306 scales += group_size;
312 for (; output + group_size <= rounded_num_out; output += group_size) {
313 PartialMatrixDotVector16(wi, scales, u, rounded_num_in, num_out - output, v);
315 scales += group_size;
321 for (; output + group_size <= rounded_num_out; output += group_size) {
322 PartialMatrixDotVector8(wi, scales, u, rounded_num_in, num_out - output, v);
324 scales += group_size;