如何借助 avx2 内部函数为 Zen2 编写高效的 GEMM 微内核?

How to write an efficient GEMM micro-kernel for Zen2 by virtue of avx2 intrinsics?

提问人:kaisong 提问时间:11/7/2022 最后编辑:kaisong 更新时间:11/7/2022 访问量:177

问:

我希望能够编写快速内核,当表达式足够计算密集时,这些内核实际上可以充分利用*(*90% 也可以)使用我的硬件的计算能力。顺便说一句,在上一个问题中,我问了同样的内存性能而不是计算性能。

对于手头的问题,我计算 C += A*B',其中 A 和 B 是列大,A 在 mxk 中,B 在 kxn 中,C 在 mxn 中。我使用小 m,n 和 k=1024,因此所有数据都适合 L1 缓存。然后我重复运行内核。然后,我使用 AVX2 内部函数调整内核。

我的架构是 Zen2。每个内核有两个 AVX2 单元和 16 个寄存器,峰值性能为 5.9e+1 GigaFlops/秒。因此,我尝试使用至少 8 个独立的 FMA 指令,最多使用 16 个 avx2 寄存器。这将产生 mxn 为 4x8 或 3x12 的内核大小。我还实现了一个 4x4 内核。下面给出了所有内核及其基准性能。

void kernel3x12(double* __restrict__ c, double* __restrict__ a, double* __restrict__ b, size_t L){
    //
    __m256d a0,a1,a2, b0,b1,b2;
    //
    __m256d c00 = _mm256_setzero_pd(); __m256d c01 = _mm256_setzero_pd(); __m256d c02 = _mm256_setzero_pd();
    __m256d c10 = _mm256_setzero_pd(); __m256d c11 = _mm256_setzero_pd(); __m256d c12 = _mm256_setzero_pd();
    __m256d c20 = _mm256_setzero_pd(); __m256d c21 = _mm256_setzero_pd(); __m256d c22 = _mm256_setzero_pd();
    //
    for(size_t k=0;k<L;++k){
        //
        a0  = _mm256_set1_pd(a[ 0]);
        a1  = _mm256_set1_pd(a[ 1]);
        a2  = _mm256_set1_pd(a[ 2]);
        //
        b0  = _mm256_loadu_pd(b   );
        b1  = _mm256_loadu_pd(b+4 );
        b2  = _mm256_loadu_pd(b+8 );
        //
        c00 = _mm256_fmadd_pd(a0,b0,c00); c01 = _mm256_fmadd_pd(a0,b1,c01); c02 = _mm256_fmadd_pd(a0,b2,c02);
        c10 = _mm256_fmadd_pd(a1,b0,c10); c11 = _mm256_fmadd_pd(a1,b1,c11); c12 = _mm256_fmadd_pd(a1,b2,c12);
        c20 = _mm256_fmadd_pd(a2,b0,c20); c21 = _mm256_fmadd_pd(a2,b1,c21); c22 = _mm256_fmadd_pd(a2,b2,c22);
        //
        a+=3; b+=12;
        //
    }
    //
    _mm256_storeu_pd(c   ,_mm256_add_pd(c00,_mm256_loadu_pd(c   ))); _mm256_storeu_pd(c+4 ,_mm256_add_pd(c01,_mm256_loadu_pd(c+4 ))); _mm256_storeu_pd(c+8 ,_mm256_add_pd(c02,_mm256_loadu_pd(c+8 )));
    _mm256_storeu_pd(c+12,_mm256_add_pd(c10,_mm256_loadu_pd(c+12))); _mm256_storeu_pd(c+16,_mm256_add_pd(c11,_mm256_loadu_pd(c+16))); _mm256_storeu_pd(c+20,_mm256_add_pd(c12,_mm256_loadu_pd(c+20)));
    _mm256_storeu_pd(c+24,_mm256_add_pd(c20,_mm256_loadu_pd(c+24))); _mm256_storeu_pd(c+28,_mm256_add_pd(c21,_mm256_loadu_pd(c+28))); _mm256_storeu_pd(c+32,_mm256_add_pd(c22,_mm256_loadu_pd(c+32)));
    //
}
// 1.61+1 GigaFlops/second

void kernel_4x8(double* __restrict__ c, double* __restrict__ a, double* __restrict__ b, size_t L){
    //
    __m256d a0,a1,a2,a3, b0,b1;
    //
    __m256d c00 = _mm256_setzero_pd(); __m256d c01 = _mm256_setzero_pd();
    __m256d c10 = _mm256_setzero_pd(); __m256d c11 = _mm256_setzero_pd();
    __m256d c20 = _mm256_setzero_pd(); __m256d c21 = _mm256_setzero_pd();
    __m256d c30 = _mm256_setzero_pd(); __m256d c31 = _mm256_setzero_pd();
    //
    for(size_t k=0;k<L;++k){
        //
        a0  = _mm256_set1_pd(a[ 0]);
        a1  = _mm256_set1_pd(a[ 1]);
        a2  = _mm256_set1_pd(a[ 2]);
        a3  = _mm256_set1_pd(a[ 3]);
        //
        b0  = _mm256_loadu_pd(b   );
        b1  = _mm256_loadu_pd(b+4 );
        //
        c00 = _mm256_fmadd_pd(a0,b0,c00); c01 = _mm256_fmadd_pd(a0,b1,c01);
        c10 = _mm256_fmadd_pd(a1,b0,c10); c11 = _mm256_fmadd_pd(a1,b1,c11);
        c20 = _mm256_fmadd_pd(a2,b0,c20); c21 = _mm256_fmadd_pd(a2,b1,c21);
        c30 = _mm256_fmadd_pd(a3,b0,c20); c31 = _mm256_fmadd_pd(a3,b1,c31);
        //
        a+=4; b+=8;
        //
    }
    //
    _mm256_storeu_pd(c   ,_mm256_add_pd(c00,_mm256_loadu_pd(c   ))); _mm256_storeu_pd(c+4 ,_mm256_add_pd(c01,_mm256_loadu_pd(c+4 )));
    _mm256_storeu_pd(c+8 ,_mm256_add_pd(c10,_mm256_loadu_pd(c+8 ))); _mm256_storeu_pd(c+12,_mm256_add_pd(c11,_mm256_loadu_pd(c+12)));
    _mm256_storeu_pd(c+16,_mm256_add_pd(c20,_mm256_loadu_pd(c+16))); _mm256_storeu_pd(c+20,_mm256_add_pd(c21,_mm256_loadu_pd(c+20)));
    _mm256_storeu_pd(c+24,_mm256_add_pd(c20,_mm256_loadu_pd(c+24))); _mm256_storeu_pd(c+28,_mm256_add_pd(c21,_mm256_loadu_pd(c+28)));
    //
}
// 1.72e+1 GigaFlops/second

void kernel(double* __restrict__ c, double* __restrict__ a, double* __restrict__ b, size_t L){
    //
    __m256d b0, a0,a1,a2,a3;
    //
    __m256d c0 = _mm256_setzero_pd();
    __m256d c1 = _mm256_setzero_pd();
    __m256d c2 = _mm256_setzero_pd();
    __m256d c3 = _mm256_setzero_pd();
    //
    for(size_t k=0;k<L;++k){
        //
        b0 = _mm256_loadu_pd(b   );
        //
        a0 = _mm256_set1_pd(a[ 0]); 
        a1 = _mm256_set1_pd(a[ 1]); 
        a2 = _mm256_set1_pd(a[ 2]); 
        a3 = _mm256_set1_pd(a[ 3]); 
        //
        c0        = _mm256_fmadd_pd( a0 , b0 , c0 );
        c1        = _mm256_fmadd_pd( a1 , b0 , c1 );
        c2        = _mm256_fmadd_pd( a2 , b0 , c2 );
        c3        = _mm256_fmadd_pd( a3 , b0 , c3 );
        //
        a+=4; b+=4;
        //
    }
    //
    _mm256_storeu_pd( c    , _mm256_add_pd( c0 , _mm256_loadu_pd(c   ) ) );
    _mm256_storeu_pd( c+ 4 , _mm256_add_pd( c1 , _mm256_loadu_pd(c+ 4) ) );
    _mm256_storeu_pd( c+ 8 , _mm256_add_pd( c2 , _mm256_loadu_pd(c+ 8) ) );
    _mm256_storeu_pd( c+12 , _mm256_add_pd( c3 , _mm256_loadu_pd(c+12) ) );
    //
}
// 2.27e+1 GigaFlops/second

编译命令是

g++ -c kernel.cpp -o kernel.o -O3 -mavx -ffast-math -march=native -fno-trapping-math

我的问题是:为了获得至少 90% 的峰值性能,我必须改变什么?

为方便起见,完整的代码

/*
g++ -c kernel.cpp -o kernel.o -O3 -mavx -ffast-math -march=native -fno-trapping-math
g++ -c main.cpp -o main.o
g++ -o main.exe main.o kernel.o
./main.exe

*/

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<iomanip>
#include<random>
#include<chrono>
#include"kernel.hpp"

#define dim_m 4
#define dim_n 4

using std::cout;
using std::endl;
using std::chrono::high_resolution_clock;
using std::chrono::duration;
using std::setw;

struct rng{
    std::random_device          rd;
    std::mt19937                mt;
    std::normal_distribution<>  dist;
    rng(): rd(), mt(rd()), dist(0.0,1.0) {}
    double get(){ return dist(mt); }
};

void set_random(double* x, size_t n, rng& o){
    for(size_t i=0;i<n;++i){
        x[i] = o.get();
    }
}

void print(double* c, size_t L){
    for(int i=0;i<dim_m;++i){
        for(int j=0;j<dim_n;++j){
            std::cout << std::setw(20) << c[i*dim_n+j] << " ";
        }
        std::cout << "\n";
    }
}

void kernel_ref(double* c, double* a, double* b, size_t L){
    for(size_t k=0;k<L;++k){
        for(size_t i=0;i<dim_m;++i){
            for(size_t j=0;j<dim_n;++j){
                c[i*dim_n+j] += a[k*dim_m+i] * b[k*dim_n+j];
            }
        }
    }
}

int main(){
    
    rng o;
    
    int L = 512;
    int J = 100000;
    
    double *a, *b, *c;
    std::chrono::time_point<std::chrono::high_resolution_clock> t0,t1;
    
    // size_of(double)=8
    // aligned allocation
    /*
    a = std::aligned_alloc(L*4*sizeof(double),4*sizeof(double)); // alignment size, vector size in bytes must
    b = std::aligned_alloc(L*4*sizeof(double),4*sizeof(double));
    c = std::aligned_alloc(4*4*sizeof(double),4*sizeof(double));   // look-up for g++
    */
    a = new double[L*dim_m];
    b = new double[L*dim_n];
    c = new double[dim_m*dim_n];
    set_random(a,dim_m*L    ,o);
    set_random(b,L    *dim_n,o);
    set_random(c,dim_m*dim_n,o);
    
    int branch = 3;
    
    if(branch==1){
        kernel_ref(c,a,b,L);
    }
    if(branch==2){
        kernel(c,a,b,L);
    }
    if(branch==3){
        t0 = high_resolution_clock::now();
        for(int j=0;j<J;++j){
            kernel(c,a,b,L);
        }
        t1 = high_resolution_clock::now();
    }
    
    print(c,L);
    
    double flops = 2.0 * dim_m*dim_n * L * J;
    double time  = std::chrono::duration<double>(t1-t0).count();
    double perf  = flops / time;
    
    /*
    std::free(a);
    std::free(b);
    std::free(c);
    */
    delete[] a;
    delete[] b;
    delete[] c;
    
    cout << "\n--------------------------------------------------\n" << std::right << std::scientific;
    cout << std::left << "measures:\n";
    cout << std::left << "\t" << setw(20) << "flops"    << " : " << std::right << flops  << "\n";
    cout << std::left << "\t" << setw(20) << "time"     << " : " << std::right << time   << "\n";
    cout << std::left << "\t" << setw(20) << "perf"     << " : " << std::right << perf    * 1.0e-09 << " [GigaFlops/s] ~ "<<     16.0*3.7e+9 * 1.0e-09<<"\n"; // https://en.wikichip.org/wiki/amd/microarchitectures/zen_2#Floating_Point_Unit 
    cout << std::left << "\t" << setw(20) << "perf x32" << " : " << std::right << 32*perf * 1.0e-12 << " [TeraFlops/s] ~ "<<32.0*16.0*3.7e+9 * 1.0e-12<<"\n";
    
    return 0;
}

然后使用kernel.cpp的头文件。

结果在 Windows 10 中使用以下 MinGW 编译器。

g++ -v
Using built-in specs.
COLLECT_GCC=C:\MinGW\bin\g++.exe
COLLECT_LTO_WRAPPER=c:/mingw/bin/../libexec/gcc/mingw32/8.2.0/lto-wrapper.exe
Target: mingw32
Configured with: ../src/gcc-8.2.0/configure --build=x86_64-pc-linux-gnu --host=mingw32 --target=mingw32 --prefix=/mingw --disable-win32-registry --with-arch=i586 --with-tune=generic --enable-languages=c,c++,objc,obj-c++,fortran,ada --with-pkgversion='MinGW.org GCC-8.2.0-5' --with-gmp=/mingw --with-mpfr=/mingw --with-mpc=/mingw --enable-static --enable-shared --enable-threads --with-dwarf2 --disable-sjlj-exceptions --enable-version-specific-runtime-libs --with-libiconv-prefix=/mingw --with-libintl-prefix=/mingw --enable-libstdcxx-debug --with-isl=/mingw --enable-libgomp --disable-libvtv --enable-nls --disable-build-format-warnings
Thread model: win32
gcc version 8.2.0 (MinGW.org GCC-8.2.0-5)

PS:这个问题的目的不是为了获得一个快速的GEMM微内核,而是学习如何编写代码,以充分利用可用的计算资源。为了简化起见,将 FLOPS 内核限制在 L1-Cache 内。

C++ 性能 矩阵乘法 HPC AVX

评论

1赞 Peter Cordes 11/7/2022
这甚至会编译吗?GCC 应该拒绝,除非你也使用 ,这意味着如果调整它,你应该使用什么。顺便说一句,一个精心缓存阻止的 matmul 几乎可以使 FMA 吞吐量饱和,即使整体大小对于 L1d 来说太大。通过对 N^2 数据进行 N^3 处理,有足够的重用性使其成为可能。当然,在 L1d 中让所有输入/输出都处于热状态更容易!-mavx_mm_fmadd_pd-mfma-march=znver2
0赞 kaisong 11/7/2022
它确实编译了,性能结果是实际的。为了您的方便,我发布了整个代码。
1赞 Peter Cordes 11/7/2022
哦,里面还有一个。显然,我还不够清醒,没有注意到这一点。这让它变得毫无意义。(仅靠一个人是不够的。godbolt.org/z/r6nb7WdMK-march=native-mavx-mavx)
0赞 kaisong 11/8/2022
自从提出这个问题以来,我不停地尝试其他内核大小,在互联网上搜索咨询,并寻找与@PeterCordes这样的人的业务联系。任何对咨询/辅导感兴趣的人,请通过 [email protected] 与我联系。

答: 暂无答案