Allow empty dimension arrays in `linalg::inferContractionDims` (#69496)
This function was returning failure when any of the intersection sets was empty, but this is actually legitimate in "matrix times vector" cases, where some of the operands have lower dimensionality, implying unit-dimension semantics for the "missing" dimensions.

Example:

```mlir
func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {
      indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
                       affine_map<(d0, d1, d2) -> (d0, d1)>],
      iterator_types = ["parallel", "parallel", "reduction"]}
      ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>)
      outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
      %19 = arith.extsi %in : i16 to i32
      %20 = arith.extui %in_3 : i4 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %21, %out : i32
      linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}
```

Here, we were returning failure because `ac` is empty. With this PR, we return this useful information:

```
batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
```
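To make the classification concrete, here is a minimal Python sketch of how the four dimension sets can be derived from the indexing maps in the example. It is a toy model, not the MLIR implementation: the function name `infer_contraction_dims_sketch` and its list-based inputs are hypothetical, and it assumes that `ac` denotes the parallel dimensions shared by the first input and the output but absent from the second input.

```python
def infer_contraction_dims_sketch(lhs_dims, rhs_dims, out_dims, iterator_types):
    """Toy model of contraction-dimension inference (hypothetical helper).

    Each *_dims argument lists the loop dimensions (d0, d1, ...) that the
    operand's indexing map uses. Empty result sets are allowed.
    """
    parallel = {i for i, t in enumerate(iterator_types) if t == "parallel"}
    reduction = {i for i, t in enumerate(iterator_types) if t == "reduction"}

    # Parallel dimensions seen by each operand.
    a = set(lhs_dims) & parallel
    b = set(rhs_dims) & parallel
    c = set(out_dims) & parallel

    batch = a & b & c                 # present in both inputs and the output
    ac = (a & c) - b                  # "m": first input and output only
    bc = (b & c) - a                  # "n": second input and output only
    k = set(lhs_dims) & set(rhs_dims) & reduction  # reduced over both inputs

    return {
        "batch": sorted(batch),
        "m": sorted(ac),
        "n": sorted(bc),
        "k": sorted(k),
    }


# Indexing maps from the example: (d1, d2), (d0, d1, d2), (d0, d1).
dims = infer_contraction_dims_sketch(
    lhs_dims=[1, 2],
    rhs_dims=[0, 1, 2],
    out_dims=[0, 1],
    iterator_types=["parallel", "parallel", "reduction"],
)
print(dims)
```

For the example above, `m` comes out empty because the only parallel dimension of the first input, `d1`, also appears in the second input and is therefore classified as a batch dimension.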