18 #if !defined(__SSE4_1__)
19 #error Implementation only for SSE 4.1 capable architectures
25 #include <emmintrin.h>
26 #include <smmintrin.h>
// Accumulates the integer dot product of the int8_t vectors u and v (n
// elements each), handling 8 elements per step with SSE intrinsics and the
// remaining elements with scalar arithmetic.
// NOTE(review): this listing has gaps -- the declarations of `offset` and
// `result`, the offset increments, the closing braces and the return
// statement are not visible here. The comments below describe only the
// statements that are shown; confirm the rest against the full file.
32 static int32_t IntDotProductSSE(
const int8_t* u,
const int8_t* v,
int n) {
// Largest start offset at which a full group of 8 elements still fits in n.
33 int max_offset = n - 8;
// Take the vector path only if at least one full group of 8 is available.
38 if (offset <= max_offset) {
// First group: load the low 64 bits (8 x int8) from the start of u and v.
// The loads use u and v directly, so `offset` is presumably 0 here -- confirm.
40 __m128i packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u));
41 __m128i packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v));
// Sign-extend the 8 int8 values to 8 int16 lanes.
42 __m128i sum = _mm_cvtepi8_epi16(packed1);
43 packed2 = _mm_cvtepi8_epi16(packed2);
// Multiply the int16 lanes pairwise and add adjacent products, leaving
// 4 x int32 partial sums in `sum`.
47 sum = _mm_madd_epi16(sum, packed2);
// Remaining full groups of 8: same load / widen / multiply-add sequence at
// u + offset and v + offset, accumulated into `sum`.
48 while (offset <= max_offset) {
49 packed1 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(u + offset));
50 packed2 = _mm_loadl_epi64(reinterpret_cast<const __m128i*>(v + offset));
52 packed1 = _mm_cvtepi8_epi16(packed1);
53 packed2 = _mm_cvtepi8_epi16(packed2);
54 packed1 = _mm_madd_epi16(packed1, packed2);
55 sum = _mm_add_epi32(sum, packed1);
// Two horizontal adds reduce the 4 x int32 partial sums so that the low lane
// holds their total.
58 sum = _mm_hadd_epi32(sum, sum);
59 sum = _mm_hadd_epi32(sum, sum);
// Extract the low 32-bit lane as the scalar result.
60 result = _mm_cvtsi128_si32(sum);
// Scalar product of one element pair added to the result; presumably the body
// of a loop over the elements left after the last full group of 8 (loop
// header not visible here).
63 result += u[offset] * v[offset];
// Computes a single output value from one run of int8 weights:
//   *v = (dot(u, wi[0..num_in)) / INT8_MAX + wi[num_in]) * *scales
// NOTE(review): the declaration of `v` is not visible in this listing;
// presumably it is a double* output parameter -- confirm against the full file.
70 static void PartialMatrixDotVector1(
const int8_t* wi,
const double* scales,
71 const int8_t* u,
int num_in,
// Integer dot product of u with the first num_in entries of wi, converted
// to double.
73 double total = IntDotProductSSE(u, wi, num_in);
// Divide the dot product by INT8_MAX, add the entry stored right after the
// num_in weights (wi[num_in]; presumably a bias term -- confirm), then
// multiply by the scale factor read from *scales.
75 *v = (total / INT8_MAX + wi[num_in]) * *scales;
78 static void matrixDotVector(
int dim1,
int dim2,
const int8_t* wi,
79 const double* scales,
const int8_t* u,
double* v) {
80 const int num_out = dim1;
81 const int num_in = dim2 - 1;
84 for (; output < num_out; output++) {
85 PartialMatrixDotVector1(wi, scales, u, num_in, v);