SIMD: avx2_select_if 函数实现

本文主要是讲述下 StarRocks 中如何利用 SIMD 指令来实现 select_if 的,其中关于 SIMD 的相关 API 含义直接参考 intel-intrinsics-guide 文档

avx2_select_if_common_impl

调用 avx2_select_if_common_impl 函数的前提是 sizeof(T) 是 {2,4,8} 中的一个。

1
2
3
4
template <typename T>
constexpr bool could_use_common_select_if() {
return sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8;
}

select_if 的简单实现就是使用一个 for-loop 对 conditions 中每个条件进行判断:

1
2
for cond, a, b in zip(conditions, vector_a, vector_b):
dst.append(a if cond == True else b)

要用 SIMD 指令来实现 avx2_select_if, 思路就是将原本一次处理 8bit 转变为现在一次性处理 256 bit(假设支持 avx2 指令集):
avx2_select_if-1

T1-load

首先需要读取 selector、a、b 中的值,每次读取 256 bit,保存为 __m256i 类型。__m256i 实际上就是存储类型为 integer、大小为 256 bit 的数组。

  1. 使用 _mm256_loadu_si256 函数将数据从内存读取到寄存器中,需要保证读取的内存确实有 256 bit 的连续内存,否则会触发 segment fault。

  2. selector 转化为 __m256i 对象 loaded_mask 后,需要逐字节(bytewise)和 0x00 进行比较

    这里使用 _mm256_cmpeq_epi8 函数逐字节将 loaded_mask 和 0x00 进行比较:相等的为 0x00 ,不等为 0xff。再对比较结果进行取反,相等的为 0xff,不等为 0x00。

    这一步作用:由于 selector 是 uint8_t[] 类型,每个元素的值要么是 0x00 要么是 0x01,通过 comp 和取反操作后,0x00 仍然是 0x00,而 0x01 变成 0xff,可以看做是『适配』SIMD API 的行为,因为很多 SIMD API 操作依赖 MEM[i*8+7] 的值来进行判断,后续比较方便。

  3. 使用 _mm256_movemask_epi8 函数将 loaded_mask 中每个字节的最高位组合起来,生成 32 bit = 256 bit / 8 整数 mask。

    至于这一步为啥怎么做,还不太清楚?后面还有个 data_mask 用来将 mask 扩充

模板参数中的 left_const 和 right_const 表示比较的左侧和右侧是否是个常数:

1
2
select .. where a > 10; -- left_const = true, right_const = false
select .. where a > b; -- left_const = false, right_const = false

读取 a、b 的数据至 vec_a、vec_b 的过程比较简单,这部分代码及注释如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
template <typename T, bool left_const = false, bool right_const = false>
inline void avx2_select_if_common_impl(uint8_t*& selector, T*& dst,
const T*& a, const T*& b, int size) {
const T* dst_end = dst + size;
constexpr int data_size = sizeof(T);

// 需要确保还剩余足够的内存
while (dst + 32 < dst_end) {
__m256i loaded_mask = _mm256_loadu_si256(reinterpret_cast<__m256i*>(selector));
loaded_mask = _mm256_cmpeq_epi8(loaded_mask, _mm256_setzero_si256());
loaded_mask = ~loaded_mask;
uint32_t mask = _mm256_movemask_epi8(loaded_mask);

__m256i vec_a[data_size];
__m256i vec_b[data_size];
__m256i vec_dst[data_size];

// load data from data vector
for (int i = 0; i < data_size; ++i) {
if constexpr (!left_const) {
vec_a[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a) + i);
} else {
vec_a[i] = SIMDUtils::set_data(*a);
}
if constexpr (!right_const) {
vec_b[i] = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b) + i);
} else {
vec_b[i] = SIMDUtils::set_data(*b);
}
}
//...
}

T2-Blendv

由于已经基于 a, b 生成 __m256i 对象 vec_a, vec_b,向量化实现 avx2_select_if 的关键在于生成 __m256i 的 condition,最终调用 _mm256_blendv_epi8 函数来实现向量化三元表达式。

由于在 T1-load 中已经将 selector 转化为 uint32_t 类型的 mask,下面需要将 mask 恢复成 __m256i。

data_size 的值是 sizeof(T)(比如 uint16_t 的 data_size 是 2),for-loop 迭代 data_size 次,因此每次迭代处理 each_loop_handle_sz bit = 32 / data_size bit(其中 32 表征 mask 是 uint32_t 类型),即每次处理向量中的元素 vec_a[i],vec_b[i] 的位数。

因此, uint16_t 就是每次迭代处理 16bit, uint32_t 每次处理 8 bit, 而 uint64_t 每次处理 4 bit,对应到 mask_table 中就是 0xFFFF、0xFF、0X0F,即通过 select_mask = mask & mask_table[data_size] 来取出每次迭代操作的位。

下面以 uint16 为例。

  1. 在第一轮迭代中 select_mask 取的是 mask 的低 16 位,高 16 位为 0,即

    1
    2
    select_mask[15:0] = mask[15:0]
    select_mask[31:16] = 0

    _mm256_set1_epi16(select_mask) 函数用 select_mask[15:0] 填充满 256 bit,即 select_vector 中连续存储了 16 个 select_mask[15:0],效果就类似于 memset。

    由于 select_mask[15:0] 中的每一个bit,对应着 loaded_mask 中 1 个字节最高位,因此这里为了取出该字节,使用 data_mask 恢复。比如 data_mask 前五个数据是:

    1
    2
    3
    4
    5
    6
    0x0001 // 取的是第一个字节
    0x0002 // 取的是第二个字节
    0x0004 // 取的是第三个字节
    0x0008 // 取的是第四个字节
    0x0010 // 取的是第五个字节
    ...

    再将 select_vector = select_vector & data_mask,得到就是每个字节的 mask。

    由于 _mm256_blendv_epi8 函数是按照 mask[i*8+7] 位置进行判断的,因此,需要将 select_vector 中 0x01 扩充成为 0xff,这就是 _mm256_cmpeq_epi16 函数和取反操作的目的。

    到此,第一轮结束,_mm256_blendv_epi8 函数所需的 vec_a, vec_b, select_vector 都具备了,则可以一次性处理 256 bit,结构存于 vec_dst。

  2. 下一轮迭代即处理 mask 的高 16 位

    此时只需要将 mask >> each_loop_handle_sz,即高 16 位变成低 16 位,重复上述逻辑。

data_size 为 4、8的逻辑类似,代码如下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
//...
// 2.
constexpr uint32_t mask_table[] = {0, 0xFFFFFFFF, 0xFFFF,
0, 0xFF,
0, 0, 0, 0x0F,
0, 0, 0, 0, 0, 0, 0, 0x03};
constexpr uint8_t each_loop_handle_sz = 32 / data_size;
for (int i = 0; i < data_size; ++i) {
uint32_t select_mask = mask & mask_table[data_size];
__m256i select_vector;
if constexpr (data_size == 2) {
select_vector = _mm256_set1_epi16(select_mask);
const __m256i data_mask = _mm256_setr_epi16(
0x0001, 0x0002, 0x0004, 0x0008, 0x0010, 0x0020, 0x0040, 0x0080,
0x0100, 0x0200, 0x0400, 0x0800, 0x1000, 0x2000, 0x4000, 0x8000);

select_vector &= data_mask;
select_vector = _mm256_cmpeq_epi16(select_vector, _mm256_setzero_si256());
select_vector = ~select_vector;
} else if constexpr (data_size == 4) {
select_vector = _mm256_set1_epi8(select_mask);
const __m256i data_mask = _mm256_setr_epi8(
0x00, 0x00, 0x00, 0x01,
0x00, 0x00, 0x00, 0x02,
0x00, 0x00, 0x00, 0x04,
0x00, 0x00, 0x00, 0x08,
0x00, 0x00, 0x00, 0x10,
0x00, 0x00, 0x00, 0x20,
0x00, 0x00, 0x00, 0x40,
0x00, 0x00, 0x00, 0x80);

select_vector &= data_mask;
select_vector = _mm256_cmpeq_epi32(select_vector, _mm256_setzero_si256());
select_vector = ~select_vector;
} else if constexpr (data_size == 8) {
select_vector = _mm256_set1_epi8(select_mask);
const __m256i data_mask = _mm256_setr_epi8(
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x01,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x02,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x04,
0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x08);

select_vector &= data_mask;
select_vector = _mm256_cmpeq_epi64(select_vector, _mm256_setzero_si256());
select_vector = ~select_vector;
}

vec_dst[i] = _mm256_blendv_epi8(vec_b[i], vec_a[i], select_vector);
mask >>= each_loop_handle_sz;
}
//...

T3-store

T3-store 就是存储结果:将寄存器 vec_dst 中的数据存储到内存 dst 中,并继续迭代下一个 256bit。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
template <typename T, bool left_const = false, bool right_const = false>
inline void avx2_select_if_common_impl(uint8_t*& selector, T*& dst,
const T*& a, const T*& b, int size) {
while (dst + 32 < dst_end) {
//...
// 3. 保存结果
for (int i = 0; i < data_size; ++i) {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dst) + i, vec_dst[i]);
}

// 4. 迭代
dst += 32;
selector += 32;
if constexpr (!left_const) {
a += 32;
}
if constexpr (!right_const) {
b += 32;
}
}
}

avx2_select_if

这个版本的 avx2_select_if 是一个特化,只适用于 sizeof(T) 为 1 的情况,是 avx2_select_if_common_impl 的简略版,不需要 data_mask 了。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
template <typename T, 
bool left_const = false, bool right_const = false,
std::enable_if_t<sizeof(T) == 1, int> = 1>
inline void avx2_select_if(uint8_t*& selector, T*& dst,
const T*& a, const T*& b, int size) {
const T* dst_end = dst + size;
while (dst + 32 < dst_end) {
auto loaded_mask = _mm256_loadu_si256(reinterpret_cast<__m256i*>(selector));
loaded_mask = _mm256_cmpeq_epi8(loaded_mask, _mm256_setzero_si256());
loaded_mask = ~loaded_mask;
__m256i vec_a;
__m256i vec_b;
if constexpr (!left_const) {
vec_a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
} else {
vec_a = _mm256_set1_epi8(*a);
}
if constexpr (!right_const) {
vec_b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b));
} else {
vec_b = _mm256_set1_epi8(*b);
}
__m256i res = _mm256_blendv_epi8(vec_b, vec_a, loaded_mask);
_mm256_storeu_si256(reinterpret_cast<__m256i*>(dst), res);
dst += 32;
selector += 32;
if (!left_const) {
a += 32;
}
if (!right_const) {
b += 32;
}
}
}

avx2_select_if

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
template <typename T, std::enable_if_t<sizeof(T) == 4, int> = 4>
inline void avx2_select_if(uint8_t*& selector, T*& dst,
const T*& a, const T*& b, int size) {
const T* dst_end = dst + size;

while (dst + 8 < dst_end) {
uint64_t value = UNALIGNED_LOAD64(selector);
__m128i v = _mm_set1_epi64x(value);
__m256i loaded_mask = _mm256_cvtepi8_epi32(v);
__m256i cond = _mm256_cmpeq_epi8(loaded_mask, _mm256_setzero_si256());
cond = ~cond;

__m256i mask = _mm256_set_epi8(
0x0c, 0xff, 0xff, 0xff,
0x08, 0xff, 0xff, 0xff,
0x04, 0xff, 0xff, 0xff,
0x00, 0xff, 0xff, 0xff,
0x0c, 0xff, 0xff, 0xff,
0x08, 0xff, 0xff, 0xff,
0x04, 0xff, 0xff, 0xff,
0x00, 0xff, 0xff, 0xff);
cond = _mm256_shuffle_epi8(cond, mask);

__m256i vec_a = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(a));
__m256i vec_b = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(b));
__m256 res = _mm256_blendv_ps(_mm256_castsi256_ps(vec_b),
_mm256_castsi256_ps(vec_a),
_mm256_castsi256_ps(cond));
_mm256_storeu_si256(
reinterpret_cast<__m256i*>(dst), _mm256_castps_si256(res));

dst += 8;
selector += 8;
a += 8;
b += 8;
}
}

SIMD_selector

SIMD_selector 其中的一个 select_if 函数封装上述三个函数,其他也类似。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
template <PrimitiveType TYPE>
class SIMD_selector {
public:
using Container = typename RunTimeColumnType<TYPE>::Container;
using CppType = RunTimeCppType<TYPE>;
using SelectVec = uint8_t*;

// select if var var
// dst[i] = select_vec[i] ? a[i] : b[i]
static void select_if(SelectVec select_vec, Container& dst,
const Container& a, const Container& b) {
int size = dst.size();
auto* start_dst = dst.data();
auto* end_dst = dst.data() + size;

auto* start_a = a.data();
auto* start_b = b.data();

#ifdef __AVX2__
if constexpr (sizeof(CppType) == 1) {
avx2_select_if(select_vec, start_dst, start_a, start_b, size);
} else if constexpr (sizeof(CppType) == 4) {
avx2_select_if(select_vec, start_dst, start_a, start_b, size);
} else if constexpr (could_use_common_select_if<CppType>()) {
avx2_select_if_common_impl(select_vec, start_dst, start_a, start_b, size);
}
#endif

while (start_dst < end_dst) {
*start_dst = *select_vec ? *start_a : *start_b;
select_vec++;
start_dst++;
start_a++;
start_b++;
}
}
}