[mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) (#69913)
This PR adds a test that performs GEMM 128x128x128 (F32 += F16 * F16). It uses `sm_90a` features in NVGPU dialect. Simplified algorithm is as follows: **Prologue** ``` mgroup = mbarriers.init x 2 tma.load ... shmem_buffer_lhs<0 x 128 x 64> tma.load ... shmem_buffer_rhs<0 x 64 x 64> tma.load ... shmem_buffer_rhs<0 x 64 x 64> mbarrier.expect_tx 32768 tma.load ... shmem_buffer_lhs<1 x 128 x 64> tma.load ... shmem_buffer_rhs<1 x 64 x 64> tma.load ... shmem_buffer_rhs<1 x 64 x 64> mbarrier.expect_tx 32768 ``` **Mainloop** ``` matrixD = for(i = 0;...2) { mbarrier.try_wait [i] lhs = shmem_buffer_lhs<pipe x 128 x 64> rhs = shmem_buffer_rhs<pipe x 64 x 128> yield nvgpu.warpgroup.mma (lhs, rhs) // Expanded : nvgpu.warpgroup.mma [128][128]+=[128][64]*[64][128] // wgmma.m64n128k16(A[0:64][0:16] * B[0:16][0:128]) // wgmma.m64n128k16(A[0:64][16:32] * B[16:32][0:128]) // wgmma.m64n128k16(A[0:64][32:48] * B[32:48][0:128]) // wgmma.m64n128k16(A[0:64][48:64] * B[48:64][0:128]) // wgmma.m64n128k16(A[64:128][0:16] * B[0:16][0:128]) // wgmma.m64n128k16(A[64:128][16:32] * B[16:32][0:128]) // wgmma.m64n128k16(A[64:128][32:48] * B[32:48][0:128]) // wgmma.m64n128k16(A[64:128][48:64] * B[48:64][0:128]) ``` **Epilogue** ``` //reg->shmem warpgroup.mma.store matrixD, shmem //shmem->glbmem parallel-for(i=0;...128) parallel-for(j=0;...128) store shmem, globalmem ```
Loading
Please sign in to comment