Unverified Commit 51916f0c authored by Guray Ozen's avatar Guray Ozen Committed by GitHub

[mlir] Add sm_90a GEMM test 128x128x128 (F32 += F16 * F16) (#69913)

This PR adds a test that performs GEMM 128x128x128 (F32 += F16 * F16).
It uses `sm_90a` features of the NVGPU dialect.

The simplified algorithm is as follows:

**Prologue** 
```
mgroup = mbarriers.init x 2
// stage 0: one 128x64 LHS tile, plus two 64x64 RHS tiles forming the 64x128 RHS operand
tma.load ... shmem_buffer_lhs<0 x 128 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
tma.load ... shmem_buffer_rhs<0 x 64 x 64>
mbarrier.expect_tx 32768
// stage 1: the same tiles for the second pipeline stage
tma.load ... shmem_buffer_lhs<1 x 128 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
tma.load ... shmem_buffer_rhs<1 x 64 x 64>
mbarrier.expect_tx 32768
```
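A quick sketch (plain Python, illustrative only) of where the `mbarrier.expect_tx 32768` value comes from: each pipeline stage TMA-loads one 128x64 f16 LHS tile plus two 64x64 f16 RHS tiles, and f16 is 2 bytes per element.

```python
# Byte count expected by one pipeline stage's mbarrier (illustrative names).
BYTES_PER_F16 = 2

lhs_bytes = 128 * 64 * BYTES_PER_F16       # one shmem_buffer_lhs tile
rhs_bytes = 2 * 64 * 64 * BYTES_PER_F16    # two shmem_buffer_rhs tiles

expect_tx = lhs_bytes + rhs_bytes
print(expect_tx)  # 32768, matching mbarrier.expect_tx above
```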
**Mainloop**
```
matrixD =
  for (i = 0; ... 2) {
    mbarrier.try_wait [i]
    lhs = shmem_buffer_lhs<i x 128 x 64>
    rhs = shmem_buffer_rhs<i x 64 x 128>
    yield nvgpu.warpgroup.mma (lhs, rhs)
  }

//   Expanded : nvgpu.warpgroup.mma [128][128] += [128][64] * [64][128]
//                  wgmma.m64n128k16(A[0:64][0:16]   *  B[0:16][0:128])
//                  wgmma.m64n128k16(A[0:64][16:32]  *  B[16:32][0:128])
//                  wgmma.m64n128k16(A[0:64][32:48]  *  B[32:48][0:128])
//                  wgmma.m64n128k16(A[0:64][48:64]  *  B[48:64][0:128])
//                  wgmma.m64n128k16(A[64:128][0:16]  *  B[0:16][0:128])
//                  wgmma.m64n128k16(A[64:128][16:32] *  B[16:32][0:128])
//                  wgmma.m64n128k16(A[64:128][32:48] *  B[32:48][0:128])
//                  wgmma.m64n128k16(A[64:128][48:64] *  B[48:64][0:128])
```
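The expansion comments above can be checked with a pure-Python model (illustrative only, all names hypothetical): one mainloop stage computes D[128][128] += A[128][64] * B[64][128] as eight `wgmma.m64n128k16` tile MMAs, covering the two 64-row halves of A and the four 16-wide K slices.

```python
# One pipeline stage of the GEMM, decomposed exactly as in the expansion
# comments: 2 row halves x 4 K slices = 8 tile MMAs of shape m64n128k16.
M, N, K = 128, 128, 64

# Deterministic small-integer operands so the comparison is exact.
A = [[(i + k) % 7 for k in range(K)] for i in range(M)]
B = [[(k * 3 + j) % 5 for j in range(N)] for k in range(K)]
D = [[0] * N for _ in range(M)]

def wgmma_m64n128k16(D, A, B, m0, k0):
    """D[m0:m0+64][0:128] += A[m0:m0+64][k0:k0+16] * B[k0:k0+16][0:128]."""
    for i in range(m0, m0 + 64):
        Di = D[i]
        for k in range(k0, k0 + 16):
            a = A[i][k]
            Bk = B[k]
            for j in range(N):
                Di[j] += a * Bk[j]

# The eight expanded tile MMAs from the comments above.
for m0 in (0, 64):
    for k0 in (0, 16, 32, 48):
        wgmma_m64n128k16(D, A, B, m0, k0)

# Plain triple-loop reference for the same stage.
D_ref = [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
         for i in range(M)]

print(D == D_ref)  # True: the eight tile MMAs reproduce the full stage
```

Each (row, K-slice) pair is visited exactly once, so the sum of the eight tile updates equals the full 128x64 by 64x128 product for that stage.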

**Epilogue** 
```
// reg -> shmem
warpgroup.mma.store matrixD, shmem
// shmem -> glbmem
parallel-for (i = 0; ... 128)
  parallel-for (j = 0; ... 128)
    store shmem, globalmem
```
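A serial Python model of the epilogue above (all names illustrative): the accumulator `matrixD` goes registers -> shared memory, then shared memory -> global memory element by element; in the real kernel both copy loops are parallel-for across threads.

```python
# Two-step epilogue store, modeled serially (hypothetical sketch).
TILE = 128

matrixD = [[float(i * TILE + j) for j in range(TILE)] for i in range(TILE)]

# reg -> shmem: models `warpgroup.mma.store matrixD, shmem`.
shmem = [row[:] for row in matrixD]

# shmem -> glbmem: models the nested parallel-for store, done serially here.
glbmem = [[0.0] * TILE for _ in range(TILE)]
for i in range(TILE):
    for j in range(TILE):
        glbmem[i][j] = shmem[i][j]

print(glbmem == matrixD)  # True: the staged copy preserves the tile
```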
parent a00caad6