[mlir] Add sm_90a GEMM test 128x128x128 (F32 =F16*F16) with predicate (#70028)
PR #69913 added a GEMM test (128x128x128 F32 += F16 * F16) with if-statement. This PR adds the same test using predicates in PTX. Predicate support is enabled using _BasicPtxBuilderInterface_ `(nvgpu.opcode ..., predicate = %pred)`. The predicate condition is computed in `Step 2. [GPU] Elect fastest thread in CTA` inspired by cutlass. It is as follows: ``` lane_predicate = nvvm.elect.sync warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0) warp_idx_in_warp_group = warp_idx % 4 predicate = (lane_predicate & warp_idx_in_warp_group) ``` Depends on #70027 #69934 #69935 #69584
Loading
Please sign in to comment