提问人:Carl 提问时间:11/16/2023 更新时间:11/16/2023 访问量:66
SIMD 算法,用于检查整数块是否为“连续”。
SIMD algorithm to check of if an integer block is "consecutive."
问:
如何检查对齐的 16 个块是否连续(并且不断增加)?u32
例如:is。
而且,不是。[100, 101, 102, ..., 115]
[100, 99, 3 ...]
我在 AVX512f 上。这是我目前所拥有的:
算法 A:
* predefine DECREASE_U32, a u32x16 of [15,14,13,...0]
* let a = input + DECREASE_32 // wrapping is OK
* compare a to u32x16::splat(first_item(a))
* Return whether all true
交替(算法 B)
* let b = copy of A
* permute the elements of b by one position
* let b = a-b
* Is b all 1's (except for 1st position)
我在 Rust 中使用板条箱执行此操作,但任何语言/伪代码都可以。(我希望有一个 SIMD 操作来减去相邻的项目。packed_simd
答:
我认为你的第一个想法可能是最好的,如果在一个循环中完成,可以摊销加载向量常量的成本。AVX-512 可以有效地做到这一点。
要么使用矢量负载,然后使用 ,要么使用矢量负载和广播负载分别广播低元素。例如 / .vpbroadcastd
vpaddd zmm16, zmm31, [rdi]{1to16}
vpcmpeqd k1, zmm16, [rdi]
嗯,但是然后检查所有元素是否为真,我想也许会有一个常数,并检查低 16 位是否为零?或者只是一个整数寄存器进行比较,就像我们对 SSE/AVX 所做的那样。我试了一下,clang有一个更好的主意:比较不相等,并检查掩码是否全部为零。(即,通过检查每个元素是否不相等来检查它们是否相等。这允许在面具本身上。我将 clang 的想法应用到我的内部函数中,这样 GCC 也可以做出更好的 asm。kaddw
1
kortest
kmov
0xffff
pmovmskb
kortest
在 C++ 中:
#include <immintrin.h>
// compare for not-equal, checking the mask for 0
bool check_contig(int *p)
{
__m512i bcast_first = _mm512_set1_epi32(*p);
__m512i desired = _mm512_add_epi32(bcast_first, _mm512_setr_epi32(15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0));
__m512i v = _mm512_loadu_si512(p);
__mmask16 cmp = _mm512_cmpneq_epi32_mask(desired, v);
return cmp == 0;
}
Godbolt - 来自 GCC 和 clang 的 asm:
# GCC
check_contig(int*):
vmovdqa32 zmm0, ZMMWORD PTR .LC0[rip]
vpaddd zmm0, zmm0, DWORD PTR [rdi]{1to16}
vpcmpd k0, zmm0, ZMMWORD PTR [rdi], 4
kortestw k0, k0
sete al
vzeroupper
ret
# clang
check_contig(int*):
vpbroadcastd zmm0, dword ptr [rdi]
vpaddd zmm0, zmm0, zmmword ptr [rip + .LCPI0_0]
vpcmpneqd k0, zmm0, zmmword ptr [rdi]
kortestw k0, k0
sete al
vzeroupper
ret
因此,它们都选择加载两次而不是 ,至少在不循环中时是这样,因此向量常量也必须从 .vpbroadcastd zmm1, xmm0
.rodata
也许如果我以不同的方式写它,因为他们更喜欢一个加载,代价是额外的随机播放。(当您有 512 位 uops 在运行时,情况可能会更糟,因此 Intel CPU 会关闭端口 1 上的矢量 ALU,只留下端口 0 和 5。https://agner.org/optimize/ 和 https://uops.info/_mm512_broadcastd_epi32( _mm512_castsi512_si128(v))
)
算法 B - 避免非平凡的向量常数
也许您的第二种方法也可以有效地旋转矢量;它唯一需要的向量常数是 all-1,它可以更便宜地生成 () 而不是加载。valignd
vpternlogd
检查比较掩码可能需要一个 + 的整数来检查除一位之外的所有位,除非我们可以使用与 clang 相同的技巧并安排事情,因此我们实际上希望掩码在我们想要的地方全为零。在这种情况下,可以检查我们想要的位,而忽略我们不需要的位。kmov
and
cmp
test eax, imm32
评论
我当前 Rust 代码的核心现在是这个宏代码:
const LAST_INDEX: usize = <$simd>::lanes() - 1;
let (expected, overflowed) = $chunk[0].overflowing_add(LAST_INDEX as $scalar);
if overflowed || expected != $chunk[LAST_INDEX] {
return false;
}
let a = unsafe { <$simd>::from_slice_aligned_unchecked($chunk) } + $decrease;
let compare_mask = a.eq(<$simd>::splat(a.extract(0)));
compare_mask.all()
其中 $scalar 是 ,$simd 是,$decrease 是 [15, 14 ...0] 块。代码的第一部分抽查最后一个元素是否比第一个元素多 15 个(并处理溢出)。u32
u32x16
我请了一个智能工具来帮助我了解所生产的 SIMD 组件。它说:
vmovdqa64: 此指令将 512 位数据向量移动到 ZMM 寄存器中。 它在这里使用了两次: vmovdqa64 zmm0,zmmword ptr [...]: 加载一个 512 位向量从内存到 zmm0。vmovdqa64 zmm0,zmmword ptr [...] (在代码后面):将不同的 512 位向量加载到 zmm0 中。 vpaddd:
vpaddd zmm0,zmm0,zmmword ptr [rax+40h]:执行打包整数 添加 32 位整数。此指令添加 512 位向量 在 zmm0 中到另一个 512 位向量(从 rax + 40h) 并将结果存储回 zmm0 中。VP广播:
vpbroadcastd zmm1,xmm0:广播来自 xmm0 的 32 位整数(较低 128 位 zmm0) 跨越 zmm1 的所有通道。这将创建一个 512 位 zmm1 中的向量,其中所有元素都相同且等于值 在 xmm0 中。vpcmpeqd:
vpcmpeqd k0,zmm0,zmm1:比较 zmm0 和 zmm1 中的 32 位整数 平等。结果存储在掩码寄存器 k0 中,其中每个位 表示每对元素的比较结果。 vpternlogd:
vpternlogd zmm1,zmm1,zmm1,0FFh:执行按位三元逻辑 对操作数的每个位进行操作。具体操作是 由即时值 0xFF 确定,在本例中对应于 到按位 OR。VPMOVM2D:
vpmovm2d zmm0,k0:将位掩码从掩码寄存器 k0 移动到 通用寄存器 ZMM0。k0 的每个位都成为一个 32 位元素 在 zmm0 中。vpcmpd:
vpcmpd k0,zmm0,zmm1,4:比较 zmm0 和 zmm1 中的 32 位整数 根据作为最后一个操作数提供的谓词(这里是 4,其中 通常表示“小于”)。结果存储在掩码中 注册 K0。vmovdqu64:
vmovdqu64 zmmword ptr [rsp+50h],zmm0:移动 zmm0 中的 512 位向量 在地址 rsp + 50h 处进入内存。
kortestw k0,k0:测试掩码寄存器 k0 的内容并设置 基于结果的零标志。这通常用于条件 基于 SIMD 比较结果的分支。
VzeroUpper: 此指令用于清除所有 YMM 的上 256 位 寄存器以避免在混合 AVX-512 和传统 SSE 代码时受到处罚。 最好在调用函数之前使用此指令 AVX-512 可能不知道。
评论
vpternlogd z,z,z, 0xff
不是按位 OR。它在输出中产生全 1,因为真值表的所有 8 位都是 。即它是编译器实现 .哎呀,有没有编译成将掩码变回比较的矢量?这太可怕了。就像我在回答中解释的那样,最坏的情况只是通用寄存器(或像 GCC 使用的那样)。充其量,做相反的事情,所以你只需要检查掩码是否为全零而不是全一,你可以直接用它来做1
_mm512_set1_epi32(-1)
compare_mask.all()
kmovw eax, k0
cmp eax, 0xffff
cmp ax, -1
.eq
kortestw
compare_mask.all()
ne
eq
vmovdqu64 [rsp+50h],zmm0
0xffffffff
0
cmp dword [rdi], 0 - 15
jae would_wrap
.ne
上一个:你如何四舍五入一个数字?
评论