Skip to content
Snippets Groups Projects
Commit 3444996b authored by Alexander Belyaev's avatar Alexander Belyaev
Browse files

[mlir] Add a pattern to bufferize std.index_cast.

Differential Revision: https://reviews.llvm.org/D102088
parent a3f22d02
No related branches found
No related tags found
No related merge requests found
...@@ -34,9 +34,22 @@ public: ...@@ -34,9 +34,22 @@ public:
return success(); return success();
} }
}; };
} // namespace
namespace { class BufferizeIndexCastOp : public OpConversionPattern<IndexCastOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(IndexCastOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
IndexCastOp::Adaptor adaptor(operands);
auto tensorType = op.getType().cast<RankedTensorType>();
rewriter.replaceOpWithNewOp<IndexCastOp>(
op, adaptor.in(),
MemRefType::get(tensorType.getShape(), tensorType.getElementType()));
return success();
}
};
class BufferizeSelectOp : public OpConversionPattern<SelectOp> { class BufferizeSelectOp : public OpConversionPattern<SelectOp> {
public: public:
using OpConversionPattern::OpConversionPattern; using OpConversionPattern::OpConversionPattern;
...@@ -56,8 +69,8 @@ public: ...@@ -56,8 +69,8 @@ public:
void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter, void mlir::populateStdBufferizePatterns(BufferizeTypeConverter &typeConverter,
RewritePatternSet &patterns) { RewritePatternSet &patterns) {
patterns.add<BufferizeDimOp, BufferizeSelectOp>(typeConverter, patterns.add<BufferizeDimOp, BufferizeSelectOp, BufferizeIndexCastOp>(
patterns.getContext()); typeConverter, patterns.getContext());
} }
namespace { namespace {
...@@ -68,14 +81,15 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> { ...@@ -68,14 +81,15 @@ struct StdBufferizePass : public StdBufferizeBase<StdBufferizePass> {
RewritePatternSet patterns(context); RewritePatternSet patterns(context);
ConversionTarget target(*context); ConversionTarget target(*context);
target.addLegalDialect<memref::MemRefDialect>(); target.addLegalDialect<scf::SCFDialect, StandardOpsDialect,
target.addLegalDialect<StandardOpsDialect>(); memref::MemRefDialect>();
target.addLegalDialect<scf::SCFDialect>();
populateStdBufferizePatterns(typeConverter, patterns); populateStdBufferizePatterns(typeConverter, patterns);
// We only bufferize the case of tensor selected type and scalar condition, // We only bufferize the case of tensor selected type and scalar condition,
// as that boils down to a select over memref descriptors (don't need to // as that boils down to a select over memref descriptors (don't need to
// touch the data). // touch the data).
target.addDynamicallyLegalOp<IndexCastOp>(
[&](IndexCastOp op) { return typeConverter.isLegal(op.getType()); });
target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) { target.addDynamicallyLegalOp<SelectOp>([&](SelectOp op) {
return typeConverter.isLegal(op.getType()) || return typeConverter.isLegal(op.getType()) ||
!op.condition().getType().isa<IntegerType>(); !op.condition().getType().isa<IntegerType>();
......
...@@ -24,3 +24,16 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> { ...@@ -24,3 +24,16 @@ func @select(%arg0: i1, %arg1: tensor<f32>, %arg2: tensor<f32>) -> tensor<f32> {
%0 = select %arg0, %arg1, %arg2 : tensor<f32> %0 = select %arg0, %arg1, %arg2 : tensor<f32>
return %0 : tensor<f32> return %0 : tensor<f32>
} }
// CHECK-LABEL: func @index_cast(
// CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
func @index_cast(%tensor: tensor<i32>, %scalar: i32) -> (tensor<index>, index) {
%index_tensor = index_cast %tensor : tensor<i32> to tensor<index>
%index_scalar = index_cast %scalar : i32 to index
return %index_tensor, %index_scalar : tensor<index>, index
}
// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref<i32>
// CHECK-NEXT: %[[INDEX_MEMREF:.*]] = index_cast %[[MEMREF]]
// CHECK-SAME: memref<i32> to memref<index>
// CHECK-NEXT: %[[INDEX_TENSOR:.*]] = memref.tensor_load %[[INDEX_MEMREF]]
// CHECK: return %[[INDEX_TENSOR]]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment