[mlir][nvgpu] Improve `WarpgroupAccumulator` type to simplify IR (#68728)
`WarpgroupAccumulator` (or `!nvgpu.warpgroup.accumulator`) is a type that keeps the accumulator matrix that is used by warp-group level matrix multiplication. It is handy to have a special type for that as the matrix is distributed among the threads of the warp-group. However, current transformations requires to create and use multiple `WarpgroupAccumulator` if the shape of GEMM is larger than the supported shape of `wgmma.mma_async` instruction. This makes IR looks dense. This PR improves the transformation of `WarpgroupAccumulator` type in every nvgpu Op that uses it. **Example: Current GEMM in NVGPU-IR** ``` // Init %m1, %m2 = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> // GEMM %r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %m1, %m2 {transposeB}: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> -> !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> // Epilogue nvgpu.warpgroup.mma.store [%r1, %r2] to %sharedMemoryBuffer : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> into memref<128x128xf32,3> ``` **Example: This PR simplifies the IR as below:** ``` // Init %m = nvgpu.warpgroup.mma.init.accumulator -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> // GEMM %r1 = nvgpu.warpgroup.mma %descA, %descB, %m1 {transposeB}: !nvgpu.warpgroup.descriptor<tensor = memref<128x64xf16, 3>>, !nvgpu.warpgroup.descriptor<tensor = memref<64x128xf16, 3>>, !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> -> !nvgpu.warpgroup.accumulator<fragmented = vector<128x128xf32>> // Epilogue nvgpu.warpgroup.mma.store [%matrixD1, %matrixD2] to %sharedMemoryBuffer : !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>, !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>> into memref<128x128xf32,3> ```
Loading
Please sign in to comment