提问人:kaisong 提问时间:11/7/2022 最后编辑:kaisong 更新时间:11/7/2022 访问量:177
如何借助 avx2 内部函数为 Zen2 编写高效的 GEMM 微内核?
How to write an efficient GEMM micro-kernel for Zen2 by virtue of avx2 intrinsics?
问:
我希望能够编写快速内核,当表达式足够计算密集时,这些内核实际上可以充分利用*(*90% 也可以)使用我的硬件的计算能力。顺便说一句,在上一个问题中,我问了同样的内存性能而不是计算性能。
对于手头的问题,我计算 C += A*B',其中 A 和 B 是列大,A 在 mxk 中,B 在 kxn 中,C 在 mxn 中。我使用小 m,n 和 k=1024,因此所有数据都适合 L1 缓存。然后我重复运行内核。然后,我使用 AVX2 内部函数调整内核。
我的架构是 Zen2。每个内核有两个 AVX2 单元和 16 个寄存器,峰值性能为 5.9e+1 GigaFlops/秒。因此,我尝试使用至少 8 个独立的 FMA 指令,最多使用 16 个 avx2 寄存器。这将产生 mxn 为 4x8 或 3x12 的内核大小。我还实现了一个 4x4 内核。下面给出了所有内核及其基准性能。
void kernel3x12(double* __restrict__ c, double* __restrict__ a, double* __restrict__ b, size_t L){
//
__m256d a0,a1,a2, b0,b1,b2;
//
__m256d c00 = _mm256_setzero_pd(); __m256d c01 = _mm256_setzero_pd(); __m256d c02 = _mm256_setzero_pd();
__m256d c10 = _mm256_setzero_pd(); __m256d c11 = _mm256_setzero_pd(); __m256d c12 = _mm256_setzero_pd();
__m256d c20 = _mm256_setzero_pd(); __m256d c21 = _mm256_setzero_pd(); __m256d c22 = _mm256_setzero_pd();
//
for(size_t k=0;k<L;++k){
//
a0 = _mm256_set1_pd(a[ 0]);
a1 = _mm256_set1_pd(a[ 1]);
a2 = _mm256_set1_pd(a[ 2]);
//
b0 = _mm256_loadu_pd(b );
b1 = _mm256_loadu_pd(b+4 );
b2 = _mm256_loadu_pd(b+8 );
//
c00 = _mm256_fmadd_pd(a0,b0,c00); c01 = _mm256_fmadd_pd(a0,b1,c01); c02 = _mm256_fmadd_pd(a0,b2,c02);
c10 = _mm256_fmadd_pd(a1,b0,c10); c11 = _mm256_fmadd_pd(a1,b1,c11); c12 = _mm256_fmadd_pd(a1,b2,c12);
c20 = _mm256_fmadd_pd(a2,b0,c20); c21 = _mm256_fmadd_pd(a2,b1,c21); c22 = _mm256_fmadd_pd(a2,b2,c22);
//
a+=3; b+=12;
//
}
//
_mm256_storeu_pd(c ,_mm256_add_pd(c00,_mm256_loadu_pd(c ))); _mm256_storeu_pd(c+4 ,_mm256_add_pd(c01,_mm256_loadu_pd(c+4 ))); _mm256_storeu_pd(c+8 ,_mm256_add_pd(c02,_mm256_loadu_pd(c+8 )));
_mm256_storeu_pd(c+12,_mm256_add_pd(c10,_mm256_loadu_pd(c+12))); _mm256_storeu_pd(c+16,_mm256_add_pd(c11,_mm256_loadu_pd(c+16))); _mm256_storeu_pd(c+20,_mm256_add_pd(c12,_mm256_loadu_pd(c+20)));
_mm256_storeu_pd(c+24,_mm256_add_pd(c20,_mm256_loadu_pd(c+24))); _mm256_storeu_pd(c+28,_mm256_add_pd(c21,_mm256_loadu_pd(c+28))); _mm256_storeu_pd(c+32,_mm256_add_pd(c22,_mm256_loadu_pd(c+32)));
//
}
// 1.61+1 GigaFlops/second
void kernel_4x8(double* __restrict__ c, double* __restrict__ a, double* __restrict__ b, size_t L){
//
__m256d a0,a1,a2,a3, b0,b1;
//
__m256d c00 = _mm256_setzero_pd(); __m256d c01 = _mm256_setzero_pd();
__m256d c10 = _mm256_setzero_pd(); __m256d c11 = _mm256_setzero_pd();
__m256d c20 = _mm256_setzero_pd(); __m256d c21 = _mm256_setzero_pd();
__m256d c30 = _mm256_setzero_pd(); __m256d c31 = _mm256_setzero_pd();
//
for(size_t k=0;k<L;++k){
//
a0 = _mm256_set1_pd(a[ 0]);
a1 = _mm256_set1_pd(a[ 1]);
a2 = _mm256_set1_pd(a[ 2]);
a3 = _mm256_set1_pd(a[ 3]);
//
b0 = _mm256_loadu_pd(b );
b1 = _mm256_loadu_pd(b+4 );
//
c00 = _mm256_fmadd_pd(a0,b0,c00); c01 = _mm256_fmadd_pd(a0,b1,c01);
c10 = _mm256_fmadd_pd(a1,b0,c10); c11 = _mm256_fmadd_pd(a1,b1,c11);
c20 = _mm256_fmadd_pd(a2,b0,c20); c21 = _mm256_fmadd_pd(a2,b1,c21);
c30 = _mm256_fmadd_pd(a3,b0,c20); c31 = _mm256_fmadd_pd(a3,b1,c31);
//
a+=4; b+=8;
//
}
//
_mm256_storeu_pd(c ,_mm256_add_pd(c00,_mm256_loadu_pd(c ))); _mm256_storeu_pd(c+4 ,_mm256_add_pd(c01,_mm256_loadu_pd(c+4 )));
_mm256_storeu_pd(c+8 ,_mm256_add_pd(c10,_mm256_loadu_pd(c+8 ))); _mm256_storeu_pd(c+12,_mm256_add_pd(c11,_mm256_loadu_pd(c+12)));
_mm256_storeu_pd(c+16,_mm256_add_pd(c20,_mm256_loadu_pd(c+16))); _mm256_storeu_pd(c+20,_mm256_add_pd(c21,_mm256_loadu_pd(c+20)));
_mm256_storeu_pd(c+24,_mm256_add_pd(c20,_mm256_loadu_pd(c+24))); _mm256_storeu_pd(c+28,_mm256_add_pd(c21,_mm256_loadu_pd(c+28)));
//
}
// 1.72e+1 GigaFlops/second
void kernel(double* __restrict__ c, double* __restrict__ a, double* __restrict__ b, size_t L){
//
__m256d b0, a0,a1,a2,a3;
//
__m256d c0 = _mm256_setzero_pd();
__m256d c1 = _mm256_setzero_pd();
__m256d c2 = _mm256_setzero_pd();
__m256d c3 = _mm256_setzero_pd();
//
for(size_t k=0;k<L;++k){
//
b0 = _mm256_loadu_pd(b );
//
a0 = _mm256_set1_pd(a[ 0]);
a1 = _mm256_set1_pd(a[ 1]);
a2 = _mm256_set1_pd(a[ 2]);
a3 = _mm256_set1_pd(a[ 3]);
//
c0 = _mm256_fmadd_pd( a0 , b0 , c0 );
c1 = _mm256_fmadd_pd( a1 , b0 , c1 );
c2 = _mm256_fmadd_pd( a2 , b0 , c2 );
c3 = _mm256_fmadd_pd( a3 , b0 , c3 );
//
a+=4; b+=4;
//
}
//
_mm256_storeu_pd( c , _mm256_add_pd( c0 , _mm256_loadu_pd(c ) ) );
_mm256_storeu_pd( c+ 4 , _mm256_add_pd( c1 , _mm256_loadu_pd(c+ 4) ) );
_mm256_storeu_pd( c+ 8 , _mm256_add_pd( c2 , _mm256_loadu_pd(c+ 8) ) );
_mm256_storeu_pd( c+12 , _mm256_add_pd( c3 , _mm256_loadu_pd(c+12) ) );
//
}
// 2.27e+1 GigaFlops/second
编译命令是
g++ -c kernel.cpp -o kernel.o -O3 -mavx -ffast-math -march=native -fno-trapping-math
我的问题是:为了获得至少 90% 的峰值性能,我必须改变什么?
为方便起见,完整的代码
/*
g++ -c kernel.cpp -o kernel.o -O3 -mavx -ffast-math -march=native -fno-trapping-math
g++ -c main.cpp -o main.o
g++ -o main.exe main.o kernel.o
./main.exe
*/
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<iomanip>
#include<random>
#include<chrono>
#include"kernel.hpp"
#define dim_m 4
#define dim_n 4
using std::cout;
using std::endl;
using std::chrono::high_resolution_clock;
using std::chrono::duration;
using std::setw;
struct rng{
std::random_device rd;
std::mt19937 mt;
std::normal_distribution<> dist;
rng(): rd(), mt(rd()), dist(0.0,1.0) {}
double get(){ return dist(mt); }
};
void set_random(double* x, size_t n, rng& o){
for(size_t i=0;i<n;++i){
x[i] = o.get();
}
}
void print(double* c, size_t L){
for(int i=0;i<dim_m;++i){
for(int j=0;j<dim_n;++j){
std::cout << std::setw(20) << c[i*dim_n+j] << " ";
}
std::cout << "\n";
}
}
void kernel_ref(double* c, double* a, double* b, size_t L){
for(size_t k=0;k<L;++k){
for(size_t i=0;i<dim_m;++i){
for(size_t j=0;j<dim_n;++j){
c[i*dim_n+j] += a[k*dim_m+i] * b[k*dim_n+j];
}
}
}
}
int main(){
rng o;
int L = 512;
int J = 100000;
double *a, *b, *c;
std::chrono::time_point<std::chrono::high_resolution_clock> t0,t1;
// size_of(double)=8
// aligned allocation
/*
a = std::aligned_alloc(L*4*sizeof(double),4*sizeof(double)); // alignment size, vector size in bytes must
b = std::aligned_alloc(L*4*sizeof(double),4*sizeof(double));
c = std::aligned_alloc(4*4*sizeof(double),4*sizeof(double)); // look-up for g++
*/
a = new double[L*dim_m];
b = new double[L*dim_n];
c = new double[dim_m*dim_n];
set_random(a,dim_m*L ,o);
set_random(b,L *dim_n,o);
set_random(c,dim_m*dim_n,o);
int branch = 3;
if(branch==1){
kernel_ref(c,a,b,L);
}
if(branch==2){
kernel(c,a,b,L);
}
if(branch==3){
t0 = high_resolution_clock::now();
for(int j=0;j<J;++j){
kernel(c,a,b,L);
}
t1 = high_resolution_clock::now();
}
print(c,L);
double flops = 2.0 * dim_m*dim_n * L * J;
double time = std::chrono::duration<double>(t1-t0).count();
double perf = flops / time;
/*
std::free(a);
std::free(b);
std::free(c);
*/
delete[] a;
delete[] b;
delete[] c;
cout << "\n--------------------------------------------------\n" << std::right << std::scientific;
cout << std::left << "measures:\n";
cout << std::left << "\t" << setw(20) << "flops" << " : " << std::right << flops << "\n";
cout << std::left << "\t" << setw(20) << "time" << " : " << std::right << time << "\n";
cout << std::left << "\t" << setw(20) << "perf" << " : " << std::right << perf * 1.0e-09 << " [GigaFlops/s] ~ "<< 16.0*3.7e+9 * 1.0e-09<<"\n"; // https://en.wikichip.org/wiki/amd/microarchitectures/zen_2#Floating_Point_Unit
cout << std::left << "\t" << setw(20) << "perf x32" << " : " << std::right << 32*perf * 1.0e-12 << " [TeraFlops/s] ~ "<<32.0*16.0*3.7e+9 * 1.0e-12<<"\n";
return 0;
}
然后使用kernel.cpp的头文件。
结果在 Windows 10 中使用以下 MinGW 编译器。
g++ -v
Using built-in specs.
COLLECT_GCC=C:\MinGW\bin\g++.exe
COLLECT_LTO_WRAPPER=c:/mingw/bin/../libexec/gcc/mingw32/8.2.0/lto-wrapper.exe
Target: mingw32
Configured with: ../src/gcc-8.2.0/configure --build=x86_64-pc-linux-gnu --host=mingw32 --target=mingw32 --prefix=/mingw --disable-win32-registry --with-arch=i586 --with-tune=generic --enable-languages=c,c++,objc,obj-c++,fortran,ada --with-pkgversion='MinGW.org GCC-8.2.0-5' --with-gmp=/mingw --with-mpfr=/mingw --with-mpc=/mingw --enable-static --enable-shared --enable-threads --with-dwarf2 --disable-sjlj-exceptions --enable-version-specific-runtime-libs --with-libiconv-prefix=/mingw --with-libintl-prefix=/mingw --enable-libstdcxx-debug --with-isl=/mingw --enable-libgomp --disable-libvtv --enable-nls --disable-build-format-warnings
Thread model: win32
gcc version 8.2.0 (MinGW.org GCC-8.2.0-5)
PS:这个问题的目的不是为了获得一个快速的GEMM微内核,而是学习如何编写代码,以充分利用可用的计算资源。为了简化起见,将 FLOPS 内核限制在 L1-Cache 内。
答: 暂无答案
评论
-mavx
_mm_fmadd_pd
-mfma
-march=znver2
-march=native
-mavx
-mavx
)