diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td index a706d67d2988912091a99197345ef1dcd34c6d8d..5a906ff2dafdf23d77e92f9c63af3fac4874c334 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td @@ -15,6 +15,8 @@ include "mlir/IR/OpBase.td" +def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>; + def Linalg_Dialect : Dialect { let name = "linalg"; let description = [{ diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td index 63bee92ded7ccd55e39cec999caf6aa8bb0ca98b..d54efbe37a57173c67756aee2ca4644d2204292d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td @@ -496,21 +496,25 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ let summary = "Linalg tiled loop operation"; let description = [{ This is a loop-like operation with additional properties. The arguments - also include the input and the output tensors and the attributes to specify - the iterator types. The body region of the loop contains `subtensor` - operations applied to every tensor argument of TiledLoopOp. + also include the input and the output tensors or memrefs and the attributes + to specify the iterator types. + + Parsing TiledLoopOp will set all elements of the `iterator_types` attribute + to "parallel" type, when it is absent from the custom format. + + Tensor-based version: + + The body region of the loop contains `subtensor` operations applied to + every tensor argument of TiledLoopOp. The body region must contain exactly one block that terminates with `linalg.yield` with the operands resulting from `subtensor_insert` operations. - Parsing TiledLoopOp will set all elements of the `iterator_types` attribute - to "parallel" type, when it is absent from the custom format. - Example: ```mlir - linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) + %0 = linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) ins(%lhs, %rhs : tensor<24x64xi8>, tensor<24x64xi8>) outs(%out : tensor<24x64xi8>) iterators("parallel") { @@ -528,13 +532,40 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ linalg.yield %result : tensor<24x64xi8> } ``` + + MemRef-based version: + + The body region of the loop contains `subview` operations applied to + every memref argument of TiledLoopOp. + + The body region must contain exactly one block that terminates with + `linalg.yield` with no operands. + + Example: + + ```mlir + linalg.tiled_loop (%i) = (%c0) to (%c24) step (%c4) + ins(%lhs, %rhs : memref<24x64xi8>, memref<24x64xi8>) + outs(%out : memref<24x64xi8>) + iterators("parallel") { + %lhs_sub = subview %lhs[%i, 0] [%c4, %c64] [1, 1] + : memref<24x64xi8> to memref + %rhs_sub = subview %rhs[%i, 0] [%c4, %c64] [1, 1] + : memref<24x64xi8> to memref + %out_sub = subview %out[%i, 0] [%c4, %c64] [1, 1] + : memref<24x64xi8> to memref + + %result_sub = linalg.generic ... + linalg.yield + } + ``` }]; let arguments = (ins Variadic:$lowerBound, Variadic:$upperBound, Variadic:$step, - Variadic:$inputs, - Variadic:$outputs, + Variadic:$inputs, + Variadic:$outputs, ArrayAttr:$iterator_types); let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$region); @@ -542,7 +573,7 @@ def Linalg_TiledLoopOp : Linalg_Op<"tiled_loop", [ let builders = [ OpBuilder<(ins "ValueRange":$lowerBounds, "ValueRange":$upperBounds, "ValueRange":$steps, "ValueRange":$inputs, "ValueRange":$outputs, - "ArrayRef":$iteratorTypes, + "ArrayAttr":$iteratorTypes, CArg<"function_ref", "nullptr">:$bodyBuilderFn)>, ]; diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td index f87a1eaeac8f1b7ede67b2e372e8d4f7ee4710f4..69aa7659b81cfa39b46658e1d255700d65947b0d 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td @@ -496,8 +496,6 @@ def PoolingSumOp: SingleInputPoolingBase_Op<"pooling_sum"> { //===----------------------------------------------------------------------===// // Generic Linalg ops. //===----------------------------------------------------------------------===// -def LinalgOperand: AnyTypeOf<[AnyRankedTensor, AnyStridedMemRef]>; - class LinalgOperandOfRank: Type< And<[ LinalgOperand.predicate, diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index 3b268d703a745668f44e17c9d2f5fe4ea546fd52..13cca7f19ee74e97d2bac1c415afce153ad29211 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1744,7 +1744,7 @@ static LogicalResult verify(linalg::YieldOp op) { void TiledLoopOp::build( OpBuilder &builder, OperationState &result, ValueRange lowerBounds, ValueRange upperBounds, ValueRange steps, ValueRange inputs, - ValueRange outputs, ArrayRef iteratorTypes, + ValueRange outputs, ArrayAttr iteratorTypes, function_ref bodyBuilderFn) { result.addOperands(lowerBounds); result.addOperands(upperBounds); @@ -1758,9 +1758,14 @@ void TiledLoopOp::build( static_cast(steps.size()), static_cast(inputs.size()), static_cast(outputs.size())})); - result.addAttribute(getIteratorTypesAttrName(), - builder.getStrArrayAttr(iteratorTypes)); - result.addTypes(outputs.getTypes()); + result.addAttribute(getIteratorTypesAttrName(), iteratorTypes); + + // Add output types for `RankedTensorType` output arguments. + for (Value output : outputs) { + Type outputType = output.getType(); + if (outputType.isa()) + result.addTypes(outputType); + } OpBuilder::InsertionGuard guard(builder); unsigned numIVs = steps.size(); @@ -1771,8 +1776,8 @@ void TiledLoopOp::build( if (bodyBuilderFn) { builder.setInsertionPointToStart(bodyBlock); bodyBuilderFn(builder, result.location, bodyBlock->getArguments()); + TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } - TiledLoopOp::ensureTerminator(*bodyRegion, builder, result.location); } static void print(OpAsmPrinter &p, TiledLoopOp op) {