[mlir][vector] Add pattern to drop unit dim from elementwise(a, b)) (#74817)
For vectors with either leading or trailing unit dim, replaces: elementwise(a, b) with: sc_a = shape_cast(a) sc_b = shape_cast(b) res = elementwise(sc_a, sc_b) return shape_cast(res) The newly inserted shape_cast Ops fold (before elementwise Op) and then restore (after elementwise Op) the unit dim. Vectors `a` and `b` are required to be rank > 1. Example: ```mlir %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32> %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32> ``` gets converted to: ```mlir %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32> %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32> %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32> %mul_sc = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32> %cast = vector.shape_cast %mul_sc : vector<1x[4]xf32> to vector<[4]xf32> ``` In practice, the bottom 2 shape_cast(s) will be folded away.
Loading
Please sign in to comment