[mlir][ArmSME] Fold extracts from 3D create_masks of SME-like masks (#80148)
When unrolling the reduction dimension of something like a matmul for SME, it is possible to get 3D masks, which are vectors of SME-like masks. The 2D masks for individual operations are then extracted from the 3D masks. i.e.: ```mlir %mask = vector.create_mask %nonConstantDim, %a, %b : vector<4x[4]x[4]xi1> %subMask = vector.extract %mask[2] : vector<[4]x[4]xi1> from vector<4x[4]x[4]xi1> ``` ArmSME only supports lowering 2D create_masks, so we must fold the extract into the create_mask. This can be done by checking if the extraction index is within the true region, then using that select the first dimension of the 2D mask. This is shown below. ```mlir %extractionInTrueRegion = arith.cmpi slt, %c2, %nonConstantDim : index %newMaskFrontDim = arith.select %extractionInTrueRegion, %a, %c0 : index %subMask = vector.create_mask %newMaskFrontDim, %b : vector<[4]x[4]xi1> ```
Loading
Please sign in to comment