Skip to content
Unverified Commit a00caad6 authored by Guray Ozen's avatar Guray Ozen Committed by GitHub
Browse files

[mlir] Add sm_90a GEMM test 128x128x128 (F32 =F16*F16) with predicate (#70028)

PR #69913 added a GEMM test (128x128x128 F32 += F16 * F16) with
if-statement. This PR adds the same test using predicates in PTX.
Predicate support is enabled using _BasicPtxBuilderInterface_
`(nvgpu.opcode ..., predicate = %pred)`.

The predicate condition is computed in `Step 2. [GPU] Elect fastest
thread in CTA` inspired by cutlass. It is as follows:
```
lane_predicate = nvvm.elect.sync
warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0)
warp_idx_in_warp_group = warp_idx % 4
predicate = (lane_predicate & warp_idx_in_warp_group)
```

Depends on #70027 #69934 #69935 #69584
parent 4ba50a78
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment