22 #include <immintrin.h> 30 constexpr
// Number of outputs taken from one result register: this is the count passed
// to ExtractResults for every fully used register.
int kNumOutputsPerRegister = 8;
32 constexpr
// Not referenced in the visible code. The widest kernel visible here,
// PartialMatrixDotVector64, accumulates into 8 result registers.
int kMaxOutputRegisters = 8;
34 constexpr
// Number of int8 weights consumed per 256-bit load: MultiplyGroup advances wi
// by this amount after each _mm256_loadu_si256.
int kNumInputsPerRegister = 32;
36 constexpr
// Number of inputs handled per inner-loop step of the PartialMatrixDotVector*
// kernels (j advances by this); 4 x int8 is the 32 bits that
// _mm256_broadcastd_epi32 replicates.
int kNumInputsPerGroup = 4;
38 constexpr
// Upper bound on the inner-loop group counter ig: 32 / 4 = 8 groups per
// 32-byte input load.
int kNumInputGroups = kNumInputsPerRegister / kNumInputsPerGroup;
// Accumulates one 32-byte block of weight * input products into `result`.
// Steps visible below: load 32 int8 weights from wi and advance wi past them,
// transfer each weight's sign onto the matching input byte, multiply the
// bytes with pairwise adds to 16 bits, widen to 32 bits with a multiply-add
// against `ones`, and add the 8 x 32-bit sums into `result`.
// wi is taken by reference, so the caller's pointer moves forward.
// weights and reps are overwritten on every call.
// NOTE(review): rep_input is presumably one 4-byte input group replicated
// across the register (callers build it with _mm256_broadcastd_epi32) -
// confirm. This listing appears to be missing lines: the declaration of the
// `result` parameter and the closing brace are not visible here.
50 inline void MultiplyGroup(
const __m256i& rep_input,
const __m256i& ones,
51 const int8_t*& wi, __m256i& weights, __m256i& reps,
// Unaligned 256-bit load of the next 32 weights.
54 weights = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(wi));
55 wi += kNumInputsPerRegister;
// reps = rep_input with each byte negated where the weight is negative (and
// zeroed where the weight is zero); weights then becomes its own absolute
// value. The products are unchanged, but weights can now be used as the
// unsigned operand of _mm256_maddubs_epi16.
57 reps = _mm256_sign_epi8(rep_input, weights);
58 weights = _mm256_sign_epi8(weights, weights);
// Unsigned weights x signed reps, adjacent pairs summed into 16 x 16-bit
// lanes (the instruction saturates these sums).
61 weights = _mm256_maddubs_epi16(weights, reps);
// Multiply by 16-bit ones and sum adjacent pairs: a pairwise horizontal add
// that widens the 16-bit sums to 8 x 32-bit lanes.
67 weights = _mm256_madd_epi16(weights, ones);
68 result = _mm256_add_epi32(result, weights);
// Converts num_out accumulated 32-bit sums in `result` to doubles and writes
// them to v.
// For each output the value stored is
//   (lane0(result) / INT8_MAX + *wi) * *scales
// and wi, scales and v each advance by one element; all three are references,
// so the caller's pointers move forward. After each output `result` is
// permuted by shift_id; with the shift_id the callers in this file build, that
// rotates the next lane into lane 0. `result` is modified in place.
// A num_out <= 0 makes the loop body never run.
// NOTE(review): the int8 read from wi here is presumably a bias term stored
// after the weights - confirm against the weight layout.
// NOTE(review): this listing appears to be missing lines: the declaration of
// `res`, whatever selects between the two lane-0 reads below (presumably a
// preprocessor conditional), and the closing braces are not visible.
74 inline void ExtractResults(__m256i& result, __m256i& shift_id,
75 const int8_t*& wi,
const double*& scales,
76 int num_out,
double*& v) {
77 for (
int out = 0; out < num_out; ++out) {
// Two alternative ways of reading the lowest 32-bit lane of result.
80 _mm256_extract_epi32(result, 0)
84 ((int32_t*)&result)[0]
87 *v++ = (
static_cast<double>(res) / INT8_MAX + *wi++) * *scales++;
// Bring the next output's lane into position 0 for the next iteration.
89 result = _mm256_permutevar8x32_epi32(result, shift_id);
// Kernel that accumulates into 8 result registers (up to 8 x 8 = 64 outputs)
// and then writes them out through ExtractResults.
// wi is consumed sequentially: 8 blocks of 32 int8 weights per input group
// (one per result register), followed by one int8 per output read inside
// ExtractResults. scales supplies one double per output.
// NOTE(review): this listing appears to be missing lines: the output
// parameter `v`, the targets of the `ones` / `inputs` / `rep_input`
// assignments, and the closing braces are not visible - confirm against the
// complete file.
100 static void PartialMatrixDotVector64(
const int8_t* wi,
const double* scales,
101 const int8_t* u,
int num_in,
int num_out,
// Sixteen 16-bit ones, passed to MultiplyGroup as `ones`.
106 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Permutation indices (listed from lane 7 down to lane 0): lane i receives
// lane (i + 1) mod 8, i.e. a rotate by one 32-bit lane towards lane 0.
107 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// One 8 x 32-bit accumulator per group of 8 outputs.
109 __m256i result0 = _mm256_setzero_si256();
110 __m256i result1 = _mm256_setzero_si256();
111 __m256i result2 = _mm256_setzero_si256();
112 __m256i result3 = _mm256_setzero_si256();
113 __m256i result4 = _mm256_setzero_si256();
114 __m256i result5 = _mm256_setzero_si256();
115 __m256i result6 = _mm256_setzero_si256();
116 __m256i result7 = _mm256_setzero_si256();
// j is advanced only by the inner loop, in steps of kNumInputsPerGroup.
118 for (
int j = 0; j < num_in;) {
// Unaligned load of 32 bytes starting at u + j. The load is always 32 bytes
// wide, whatever num_in - j is; presumably callers pad u - verify.
120 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
// At most kNumInputGroups groups per load, stopping early once j reaches
// num_in.
123 for (
int ig = 0; ig < kNumInputGroups && j < num_in;
124 ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 int8 inputs) of `inputs` across the register.
127 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate so the next 4 inputs sit in the low lane for the next iteration.
129 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
130 __m256i weights, reps;
// Each call consumes the next 32 weights from wi and accumulates into its own
// result register.
132 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
133 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
134 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
135 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
136 MultiplyGroup(rep_input, ones, wi, weights, reps, result4);
137 MultiplyGroup(rep_input, ones, wi, weights, reps, result5);
138 MultiplyGroup(rep_input, ones, wi, weights, reps, result6);
139 MultiplyGroup(rep_input, ones, wi, weights, reps, result7);
// The first seven registers always emit a full 8 outputs each (56 in total),
// regardless of num_out; presumably this kernel is only selected when more
// than 56 outputs remain - verify against the caller.
142 ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
143 ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
144 ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
145 ExtractResults(result3, shift_id, wi, scales, kNumOutputsPerRegister, v);
146 ExtractResults(result4, shift_id, wi, scales, kNumOutputsPerRegister, v);
147 ExtractResults(result5, shift_id, wi, scales, kNumOutputsPerRegister, v);
148 ExtractResults(result6, shift_id, wi, scales, kNumOutputsPerRegister, v);
// The last register emits only the outputs that remain, capped at 8; a
// remainder <= 0 emits nothing.
149 num_out -= kNumOutputsPerRegister * 7;
150 ExtractResults(result7, shift_id, wi, scales,
151 std::min(kNumOutputsPerRegister, num_out), v);
// Kernel that accumulates into 4 result registers (up to 4 x 8 = 32 outputs)
// and then writes them out through ExtractResults.
// wi is consumed sequentially: 4 blocks of 32 int8 weights per input group
// (one per result register), followed by one int8 per output read inside
// ExtractResults. scales supplies one double per output.
// NOTE(review): this listing appears to be missing lines: the output
// parameter `v`, the targets of the `ones` / `inputs` / `rep_input`
// assignments, and the closing braces are not visible - confirm against the
// complete file.
156 static void PartialMatrixDotVector32(
const int8_t* wi,
const double* scales,
157 const int8_t* u,
int num_in,
int num_out,
// Sixteen 16-bit ones, passed to MultiplyGroup as `ones`.
162 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Permutation indices (listed from lane 7 down to lane 0): lane i receives
// lane (i + 1) mod 8, i.e. a rotate by one 32-bit lane towards lane 0.
163 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// One 8 x 32-bit accumulator per group of 8 outputs.
165 __m256i result0 = _mm256_setzero_si256();
166 __m256i result1 = _mm256_setzero_si256();
167 __m256i result2 = _mm256_setzero_si256();
168 __m256i result3 = _mm256_setzero_si256();
// j is advanced only by the inner loop, in steps of kNumInputsPerGroup.
170 for (
int j = 0; j < num_in;) {
// Unaligned load of 32 bytes starting at u + j. The load is always 32 bytes
// wide, whatever num_in - j is; presumably callers pad u - verify.
172 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
// At most kNumInputGroups groups per load, stopping early once j reaches
// num_in.
175 for (
int ig = 0; ig < kNumInputGroups && j < num_in;
176 ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 int8 inputs) of `inputs` across the register.
179 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate so the next 4 inputs sit in the low lane for the next iteration.
181 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
182 __m256i weights, reps;
// Each call consumes the next 32 weights from wi and accumulates into its own
// result register.
184 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
185 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
186 MultiplyGroup(rep_input, ones, wi, weights, reps, result2);
187 MultiplyGroup(rep_input, ones, wi, weights, reps, result3);
// The first three registers always emit a full 8 outputs each (24 in total),
// regardless of num_out; presumably this kernel is only selected when more
// than 24 outputs remain - verify against the caller.
190 ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
191 ExtractResults(result1, shift_id, wi, scales, kNumOutputsPerRegister, v);
192 ExtractResults(result2, shift_id, wi, scales, kNumOutputsPerRegister, v);
// The last register emits only the outputs that remain, capped at 8; a
// remainder <= 0 emits nothing.
193 num_out -= kNumOutputsPerRegister * 3;
194 ExtractResults(result3, shift_id, wi, scales,
195 std::min(kNumOutputsPerRegister, num_out), v);
// Kernel that accumulates into 2 result registers (up to 2 x 8 = 16 outputs)
// and then writes them out through ExtractResults.
// wi is consumed sequentially: 2 blocks of 32 int8 weights per input group
// (one per result register), followed by one int8 per output read inside
// ExtractResults. scales supplies one double per output.
// NOTE(review): this listing appears to be missing lines: the output
// parameter `v`, the targets of the `ones` / `inputs` / `rep_input`
// assignments, and the closing braces are not visible - confirm against the
// complete file.
200 static void PartialMatrixDotVector16(
const int8_t* wi,
const double* scales,
201 const int8_t* u,
int num_in,
int num_out,
// Sixteen 16-bit ones, passed to MultiplyGroup as `ones`.
206 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Permutation indices (listed from lane 7 down to lane 0): lane i receives
// lane (i + 1) mod 8, i.e. a rotate by one 32-bit lane towards lane 0.
207 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
// One 8 x 32-bit accumulator per group of 8 outputs.
209 __m256i result0 = _mm256_setzero_si256();
210 __m256i result1 = _mm256_setzero_si256();
// j is advanced only by the inner loop, in steps of kNumInputsPerGroup.
212 for (
int j = 0; j < num_in;) {
// Unaligned load of 32 bytes starting at u + j. The load is always 32 bytes
// wide, whatever num_in - j is; presumably callers pad u - verify.
214 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
// At most kNumInputGroups groups per load, stopping early once j reaches
// num_in.
217 for (
int ig = 0; ig < kNumInputGroups && j < num_in;
218 ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 int8 inputs) of `inputs` across the register.
221 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate so the next 4 inputs sit in the low lane for the next iteration.
223 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
224 __m256i weights, reps;
// Each call consumes the next 32 weights from wi and accumulates into its own
// result register.
226 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
227 MultiplyGroup(rep_input, ones, wi, weights, reps, result1);
// The first register always emits a full 8 outputs, regardless of num_out;
// presumably this kernel is only selected when more than 8 outputs remain -
// verify against the caller.
230 ExtractResults(result0, shift_id, wi, scales, kNumOutputsPerRegister, v);
// The second register emits only the outputs that remain, capped at 8; a
// remainder <= 0 emits nothing.
231 num_out -= kNumOutputsPerRegister;
232 ExtractResults(result1, shift_id, wi, scales,
233 std::min(kNumOutputsPerRegister, num_out), v);
// Kernel that accumulates into a single result register (8 x 32-bit sums) and
// then writes num_out of them out through ExtractResults.
// wi is consumed sequentially: one block of 32 int8 weights per input group,
// followed by one int8 per output read inside ExtractResults. scales supplies
// one double per output.
// NOTE(review): this listing appears to be missing lines: the output
// parameter `v`, the targets of the `ones` / `inputs` / `rep_input`
// assignments, and the closing braces are not visible - confirm against the
// complete file.
238 static void PartialMatrixDotVector8(
const int8_t* wi,
const double* scales,
239 const int8_t* u,
int num_in,
int num_out,
// Sixteen 16-bit ones, passed to MultiplyGroup as `ones`.
244 _mm256_set_epi16(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1);
// Permutation indices (listed from lane 7 down to lane 0): lane i receives
// lane (i + 1) mod 8, i.e. a rotate by one 32-bit lane towards lane 0.
245 __m256i shift_id = _mm256_set_epi32(0, 7, 6, 5, 4, 3, 2, 1);
247 __m256i result0 = _mm256_setzero_si256();
// j is advanced only by the inner loop, in steps of kNumInputsPerGroup.
249 for (
int j = 0; j < num_in;) {
// Unaligned load of 32 bytes starting at u + j. The load is always 32 bytes
// wide, whatever num_in - j is; presumably callers pad u - verify.
251 _mm256_loadu_si256(reinterpret_cast<const __m256i*>(u + j));
// At most kNumInputGroups groups per load, stopping early once j reaches
// num_in.
254 for (
int ig = 0; ig < kNumInputGroups && j < num_in;
255 ++ig, j += kNumInputsPerGroup) {
// Replicate the low 32 bits (4 int8 inputs) of `inputs` across the register.
258 _mm256_broadcastd_epi32(_mm256_castsi256_si128(inputs));
// Rotate so the next 4 inputs sit in the low lane for the next iteration.
260 inputs = _mm256_permutevar8x32_epi32(inputs, shift_id);
261 __m256i weights, reps;
263 MultiplyGroup(rep_input, ones, wi, weights, reps, result0);
// Unlike the wider kernels, num_out is passed through without a std::min
// clamp. NOTE(review): presumably callers guarantee
// num_out <= kNumOutputsPerRegister; beyond 8 the lane rotation wraps around
// and earlier sums would be emitted again - verify.
266 ExtractResults(result0, shift_id, wi, scales, num_out, v);
279 partial_funcs_ = {PartialMatrixDotVector64, PartialMatrixDotVector32,
280 PartialMatrixDotVector16, PartialMatrixDotVector8};
int max_output_registers_
std::vector< PartialFunc > partial_funcs_
int num_outputs_per_register_
int num_inputs_per_register_
int num_inputs_per_group_