Commit 52db7e27 authored by Guray Ozen, committed by GitHub

[mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (#68728)

`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type
that holds the accumulator matrix used by warp-group level matrix
multiplication. A dedicated type is handy because the matrix is
distributed among the threads of the warp-group. However, the current
transformations require creating and using multiple
`WarpgroupAccumulator` values when the GEMM shape is larger than the
shape supported by a single `wgmma.mma_async` instruction. This makes
the IR dense.

This PR improves the transformation of the `WarpgroupAccumulator` type
in every nvgpu op that uses it, so a single value can represent the
whole accumulator.

**Example: GEMM in NVGPU-IR before this PR**
```
// Init
%m1, %m2 = nvgpu.warpgroup.mma.init.accumulator ->  
                    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
                    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>

// GEMM
%r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}: 
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> 
      -> 
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>  


// Epilogue 
nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, 
    !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    into memref<128x128xf32,3>
```

**Example: GEMM in NVGPU-IR after this PR**
```
// Init
%m = nvgpu.warpgroup.mma.init.accumulator ->  
           !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>

// GEMM
%r1 = nvgpu.warpgroup.mma %descA, %descB, %m {transposeB}: 
      !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, 
      !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, 
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> 
      -> 
      !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>  

// Epilogue 
nvgpu.warpgroup.mma.store %r1 to %sharedMemoryBuffer
  : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
    into memref<128x128xf32,3>
```
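
The benefit compounds once the GEMM sits inside a K-loop, since only one accumulator value has to be threaded through the loop-carried values. The following is a hypothetical sketch, not code from this PR: the SSA names, loop bounds, and the assumption that the same descriptors are reused each iteration are illustrative only; the op syntax mirrors the examples above.

```
// Hypothetical sketch: a single accumulator carried through scf.for iter_args.
// %c0, %cK, %c1 are illustrative index constants; in a real kernel the
// descriptors would advance along K each iteration.
%init = nvgpu.warpgroup.mma.init.accumulator ->
           !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
%acc = scf.for %k = %c0 to %cK step %c1 iter_args(%a = %init)
         -> (!nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>) {
  %next = nvgpu.warpgroup.mma %descA, %descB, %a {transposeB}:
        !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>,
        !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
        ->
        !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
  scf.yield %next : !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>>
}
```

With the pre-PR representation, each of the multiple accumulator values would need its own `iter_args` slot and `scf.yield` operand, roughly doubling the loop boilerplate for this shape.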
parent 838f2890