[MLIR][NVGPU] Adding `nvgpu.warpgroup.mma` Op for Hopper GPUs (#65440)
This work introduces a new operation called `nvgpu.warpgroup.mma` to the NVGPU dialect of MLIR. The purpose of this operation is to facilitate warpgroup-level matrix multiply and accumulate (WGMMA) operations on Hopper GPUs with the sm_90a architecture.

Previously, the `nvvm.wgmma.mma_async` operation was introduced to support warpgroup-level matrix operations in the NVVM dialect. Achieving a desired computation shape requires composing multiple instances of `nvvm.wgmma.mma_async`. The new `nvgpu.warpgroup.mma` operation abstracts this complexity and provides a higher-level interface for performing warpgroup-level matrix operations.

The `nvgpu.warpgroup.mma` op does the following:
1) Corresponds to multiple `wgmma` instructions.
2) Iterates over the input matrix descriptors to achieve the desired computation shape.
3) Groups and runs the `wgmma` instructions asynchronously, and eventually waits for them. This is done with `wgmma.fence.aligned`, `wgmma.commit.group.sync.aligned`, and `wgmma.wait.group.sync.aligned`.
4) Returns its results as fragmented matrices.

Here's an example usage of the `nvgpu.warpgroup.mma` operation:
```
%wgmmaResult, %wgmmaResult2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
      !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
      !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
      ->
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
```

In this example the overall GEMM shape is 128x128x64 (the two 64x128 accumulators together cover M = 128), while a single `wgmma.mma_async` instruction computes m64n128k16; the op therefore iterates 2 times along M and 4 times along K, emitting 8 `wgmma` instructions in total.

The op will result in the following PTX:
```
wgmma.fence.sync.aligned;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA, %descB, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA+2, %descB+128, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA+4, %descB+256, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f1, %f2, 62 more registers}, %descA+8, %descB+348, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+512, %descB, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+514, %descB+128, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+516, %descB+256, p, 1, 1, 0, 1;
wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 {%f500,%f501, 62 more registers}, %descA+518, %descB+348, p, 1, 1, 0, 1;
wgmma.commit_group.sync.aligned;
wgmma.wait_group.sync.aligned 1;
```

The op keeps the following register-to-accumulator mapping:
- the first 64 registers (`{%f1, %f2, 62 more registers}`) -> `%acc1`
- the second 64 registers (`{%f500, %f501, 62 more registers}`) -> `%acc2`
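Because the op consumes and produces accumulator values as SSA results, it composes naturally with a K-tile mainloop where the partial results are threaded through `scf.for` iteration arguments. Below is a minimal sketch of that pattern; the function name `@gemm_mainloop`, the loop bound, and the assumption that descriptors and initial accumulators arrive as function arguments are all hypothetical, and a real kernel would advance the descriptors to the next K-tile inside the loop:
```
func.func @gemm_mainloop(
    %descA: !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
    %descB: !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
    %init1: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
    %init2: !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>)
    -> (!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c4 = arith.constant 4 : index
  // Thread the accumulators through the loop so each iteration's WGMMA
  // accumulates onto the previous partial result.
  %res:2 = scf.for %kb = %c0 to %c4 step %c1
      iter_args(%acc1 = %init1, %acc2 = %init2)
      -> (!nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
          !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>) {
    // A real kernel would load the next K-tile into shared memory and
    // rebuild the descriptors here; this sketch reuses them for brevity.
    %r1, %r2 = nvgpu.warpgroup.mma %descA, %descB, %acc1, %acc2 {transposeB}:
        !nvgpu.wgmma.descriptor<tensor = memref<128x64xf16, 3>>,
        !nvgpu.wgmma.descriptor<tensor = memref<64x128xf16, 3>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
        ->
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
    scf.yield %r1, %r2 :
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
        !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
  }
  return %res#0, %res#1 :
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>,
      !nvgpu.warpgroup.accumulator<fragmented = vector<64x128xf32>>
}
```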