Skip to content
Unverified Commit c02d07fd authored by Andrzej Warzyński's avatar Andrzej Warzyński Committed by GitHub
Browse files

[mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (#74817)

For vectors with either leading or trailing unit dim, replaces:

    elementwise(a, b)

with:

    sc_a = shape_cast(a)
    sc_b = shape_cast(b)
    res = elementwise(sc_a, sc_b)
    return shape_cast(res)

The newly inserted shape_cast Ops fold (before elementwise Op) and then
restore (after elementwise Op) the unit dim. Vectors `a` and `b` are
required to be rank > 1.

Example:
```mlir
  %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
  %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
```

gets converted to:

```mlir
  %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
  %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
  %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
  %mul_sc = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
  %cast = vector.shape_cast %mul_sc : vector<1x[4]xf32> to vector<[4]xf32>
```

In practice, the bottom 2 shape_cast(s) will be folded away.
parent 9512d6d2
Loading
Loading
Loading
Loading
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment