Skip to content
Commit 9a795f0c authored by Manish Gupta's avatar Manish Gupta
Browse files

[mlir][Vector] Adds a pattern to fold `arith.extf` into `vector.contract`

Consider mixed precision data type, i.e., F16 input lhs, F16 input rhs, F32 accumulation, and F32 output. This is typically written as F32 <= F16*F16 + F32.

During vectorization from linalg to vector for mixed precision data type (F32 <= F16*F16 + F32), linalg.matmul introduces arith.extf on input lhs and rhs operands.

"linalg.matmul"(%lhs, %rhs, %acc) ({
      ^bb0(%arg1: f16, %arg2: f16, %arg3: f32):
        %lhs_f32 = "arith.extf"(%arg1) : (f16) -> f32
        %rhs_f32 = "arith.extf"(%arg2) : (f16) -> f32
       %mul = "arith.mulf"(%lhs_f32, %rhs_f32) : (f32, f32) -> f32
        %acc = "arith.addf"(%arg3, %mul) : (f32, f32) -> f32
      "linalg.yield"(%acc) : (f32) -> ()
    })
There are backend that natively supports mixed-precision data type and does not need the arith.extf. For example, NVIDIA A100 GPU has mma.sync.aligned.*.f32.f16.f16.f32 that can support mixed-precision data type. However, the presence of arith.extf in the IR, introduces the unnecessary casting targeting F32 Tensor Cores instead of F16 Tensor Cores for NVIDIA backend. This patch adds a folding pattern to fold arith.extf into vector.contract

Differential Revision: https://reviews.llvm.org/D151918
parent f04cf6b7
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment