Commit c42dd5db authored by River Riddle's avatar River Riddle
Browse files

[mlir] Add new SubElementAttr/SubElementType Interfaces

These interfaces allow for a composite attribute or type to opaquely provide access to any held attributes or types. There are several intended use cases for this interface. The first of which is to allow the printer to create aliases for non-builtin dialect attributes and types. In the future, this interface will also be extended to allow for SymbolRefAttr to be placed on other entities aside from just DictionaryAttr and ArrayAttr.

To limit potential test breakages, this revision only adds the new interfaces to the builtin attributes/types that are currently hardcoded during AsmPrinter alias generation. In a followup the remaining builtin attributes/types, and non-builtin attributes/types can be extended to support it.

Differential Revision: https://reviews.llvm.org/D102945
parent f8a1d652
......@@ -9,7 +9,7 @@
#ifndef MLIR_IR_BUILTINATTRIBUTES_H
#define MLIR_IR_BUILTINATTRIBUTES_H
#include "mlir/IR/Attributes.h"
#include "SubElementInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Sequence.h"
#include <complex>
......
......@@ -15,6 +15,7 @@
#define BUILTIN_ATTRIBUTES
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/SubElementInterfaces.td"
// TODO: Currently the attributes defined in this file are prefixed with
// `Builtin_`. This is to differentiate the attributes here with the ones in
......@@ -22,8 +23,9 @@ include "mlir/IR/BuiltinDialect.td"
// to this file instead.
// Base class for Builtin dialect attributes.
class Builtin_Attr<string name, string baseCppClass = "::mlir::Attribute">
: AttrDef<Builtin_Dialect, name, [], baseCppClass> {
class Builtin_Attr<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Attribute">
: AttrDef<Builtin_Dialect, name, traits, baseCppClass> {
let mnemonic = ?;
}
......@@ -62,7 +64,9 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap"> {
// ArrayAttr
//===----------------------------------------------------------------------===//
def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
def Builtin_ArrayAttr : Builtin_Attr<"Array", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "A collection of other Attribute values";
let description = [{
Syntax:
......@@ -133,7 +137,7 @@ def Builtin_ArrayAttr : Builtin_Attr<"Array"> {
//===----------------------------------------------------------------------===//
def Builtin_DenseIntOrFPElementsAttr
: Builtin_Attr<"DenseIntOrFPElements", "DenseElementsAttr"> {
: Builtin_Attr<"DenseIntOrFPElements", /*traits=*/[], "DenseElementsAttr"> {
let summary = "An Attribute containing a dense multi-dimensional array of "
"integer or floating-point values";
let description = [{
......@@ -228,7 +232,7 @@ def Builtin_DenseIntOrFPElementsAttr
//===----------------------------------------------------------------------===//
def Builtin_DenseStringElementsAttr
: Builtin_Attr<"DenseStringElements", "DenseElementsAttr"> {
: Builtin_Attr<"DenseStringElements", /*traits=*/[], "DenseElementsAttr"> {
let summary = "An Attribute containing a dense multi-dimensional array of "
"strings";
let description = [{
......@@ -277,7 +281,9 @@ def Builtin_DenseStringElementsAttr
// DictionaryAttr
//===----------------------------------------------------------------------===//
def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary"> {
def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "An dictionary of named Attribute values";
let description = [{
Syntax:
......@@ -589,7 +595,7 @@ def Builtin_OpaqueAttr : Builtin_Attr<"Opaque"> {
//===----------------------------------------------------------------------===//
def Builtin_OpaqueElementsAttr
: Builtin_Attr<"OpaqueElements", "ElementsAttr"> {
: Builtin_Attr<"OpaqueElements", /*traits=*/[], "ElementsAttr"> {
let summary = "An opaque representation of a multi-dimensional array";
let description = [{
Syntax:
......@@ -655,7 +661,7 @@ def Builtin_OpaqueElementsAttr
//===----------------------------------------------------------------------===//
def Builtin_SparseElementsAttr
: Builtin_Attr<"SparseElements", "ElementsAttr"> {
: Builtin_Attr<"SparseElements", /*traits=*/[], "ElementsAttr"> {
let summary = "An opaque representation of a multi-dimensional array";
let description = [{
Syntax:
......@@ -892,7 +898,9 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
// TypeAttr
//===----------------------------------------------------------------------===//
def Builtin_TypeAttr : Builtin_Attr<"Type"> {
def Builtin_TypeAttr : Builtin_Attr<"Type", [
DeclareAttrInterfaceMethods<SubElementAttrInterface>
]> {
let summary = "An Attribute containing a Type";
let description = [{
Syntax:
......
......@@ -9,8 +9,7 @@
#ifndef MLIR_IR_BUILTINTYPES_H
#define MLIR_IR_BUILTINTYPES_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
#include "SubElementInterfaces.h"
namespace llvm {
struct fltSemantics;
......
......@@ -16,14 +16,16 @@
include "mlir/IR/BuiltinDialect.td"
include "mlir/IR/BuiltinTypeInterfaces.td"
include "mlir/IR/SubElementInterfaces.td"
// TODO: Currently the types defined in this file are prefixed with `Builtin_`.
// This is to differentiate the types here with the ones in OpBase.td. We should
// remove the definitions in OpBase.td, and repoint users to this file instead.
// Base class for Builtin dialect types.
class Builtin_Type<string name, string baseCppClass = "::mlir::Type">
: TypeDef<Builtin_Dialect, name, [], baseCppClass> {
class Builtin_Type<string name, list<Trait> traits = [],
string baseCppClass = "::mlir::Type">
: TypeDef<Builtin_Dialect, name, traits, baseCppClass> {
let mnemonic = ?;
}
......@@ -66,7 +68,8 @@ def Builtin_Complex : Builtin_Type<"Complex"> {
//===----------------------------------------------------------------------===//
// Base class for Builtin dialect float types.
class Builtin_FloatType<string name> : Builtin_Type<name, "::mlir::FloatType"> {
class Builtin_FloatType<string name>
: Builtin_Type<name, /*traits=*/[], "::mlir::FloatType"> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
......@@ -118,7 +121,9 @@ def Builtin_Float128 : Builtin_FloatType<"Float128"> {
// FunctionType
//===----------------------------------------------------------------------===//
def Builtin_Function : Builtin_Type<"Function"> {
def Builtin_Function : Builtin_Type<"Function", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
]> {
let summary = "Map from a list of inputs to a list of results";
let description = [{
Syntax:
......@@ -253,7 +258,9 @@ def Builtin_Integer : Builtin_Type<"Integer"> {
// MemRefType
//===----------------------------------------------------------------------===//
def Builtin_MemRef : Builtin_Type<"MemRef", "BaseMemRefType"> {
def Builtin_MemRef : Builtin_Type<"MemRef", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
], "BaseMemRefType"> {
let summary = "Shaped reference to a region of memory";
let description = [{
Syntax:
......@@ -638,7 +645,9 @@ def Builtin_Opaque : Builtin_Type<"Opaque"> {
// RankedTensorType
//===----------------------------------------------------------------------===//
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
def Builtin_RankedTensor : Builtin_Type<"RankedTensor", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
], "TensorType"> {
let summary = "Multi-dimensional array with a fixed number of dimensions";
let description = [{
Syntax:
......@@ -726,7 +735,9 @@ def Builtin_RankedTensor : Builtin_Type<"RankedTensor", "TensorType"> {
// TupleType
//===----------------------------------------------------------------------===//
def Builtin_Tuple : Builtin_Type<"Tuple"> {
def Builtin_Tuple : Builtin_Type<"Tuple", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
]> {
let summary = "Fixed-sized collection of other types";
let description = [{
Syntax:
......@@ -793,7 +804,9 @@ def Builtin_Tuple : Builtin_Type<"Tuple"> {
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> {
def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
], "BaseMemRefType"> {
let summary = "Shaped reference, with unknown rank, to a region of memory";
let description = [{
Syntax:
......@@ -853,7 +866,9 @@ def Builtin_UnrankedMemRef : Builtin_Type<"UnrankedMemRef", "BaseMemRefType"> {
// UnrankedTensorType
//===----------------------------------------------------------------------===//
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> {
def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
], "TensorType"> {
let summary = "Multi-dimensional array with unknown dimensions";
let description = [{
Syntax:
......@@ -890,7 +905,9 @@ def Builtin_UnrankedTensor : Builtin_Type<"UnrankedTensor", "TensorType"> {
// VectorType
//===----------------------------------------------------------------------===//
def Builtin_Vector : Builtin_Type<"Vector", "ShapedType"> {
def Builtin_Vector : Builtin_Type<"Vector", [
DeclareTypeInterfaceMethods<SubElementTypeInterface>
], "ShapedType"> {
let summary = "Multi-dimensional SIMD vector type";
let description = [{
Syntax:
......
......@@ -31,6 +31,13 @@ mlir_tablegen(BuiltinTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(BuiltinTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRBuiltinTypeInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS SubElementInterfaces.td)
mlir_tablegen(SubElementAttrInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(SubElementAttrInterfaces.cpp.inc -gen-attr-interface-defs)
mlir_tablegen(SubElementTypeInterfaces.h.inc -gen-type-interface-decls)
mlir_tablegen(SubElementTypeInterfaces.cpp.inc -gen-type-interface-defs)
add_public_tablegen_target(MLIRSubElementInterfacesIncGen)
set(LLVM_TARGET_DEFINITIONS TensorEncoding.td)
mlir_tablegen(TensorEncInterfaces.h.inc -gen-attr-interface-decls)
mlir_tablegen(TensorEncInterfaces.cpp.inc -gen-attr-interface-defs)
......
//===- SubElementInterfaces.h - Attr and Type SubElements -------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains interfaces and utilities for querying the sub elements of
// an attribute or type.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_INTERFACES_SUBELEMENTINTERFACES_H
#define MLIR_INTERFACES_SUBELEMENTINTERFACES_H
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
/// Include the definitions of the sub elemnt interfaces.
#include "mlir/IR/SubElementAttrInterfaces.h.inc"
#include "mlir/IR/SubElementTypeInterfaces.h.inc"
#endif // MLIR_INTERFACES_SUBELEMENTINTERFACES_H
//===-- SubElementInterfaces.td - Sub-Element Interfaces ---*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains a set of interfaces that can be used to interface with
// sub-elements, e.g. held attributes and types, of a composite attribute or
// type.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_IR_SUBELEMENTINTERFACES_TD_
#define MLIR_IR_SUBELEMENTINTERFACES_TD_
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// SubElementInterfaceBase
//===----------------------------------------------------------------------===//
class SubElementInterfaceBase<string interfaceName, string derivedValue> {
string cppNamespace = "::mlir";
list<InterfaceMethod> methods = [
InterfaceMethod<
/*desc=*/[{
Walk all of the immediately nested sub-attributes and sub-types. This
method does not recurse into sub elements.
}], "void", "walkImmediateSubElements",
(ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
"llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
>,
];
code extraClassDeclaration = [{
/// Walk all of the held sub-attributes.
void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
}
/// Walk all of the held sub-types.
void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
}
/// Walk all of the held sub-attributes and sub-types.
void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
llvm::function_ref<void(mlir::Type)> walkTypesFn);
}];
code extraTraitClassDeclaration = [{
/// Walk all of the held sub-attributes.
void walkSubAttrs(llvm::function_ref<void(mlir::Attribute)> walkFn) {
walkSubElements(walkFn, /*walkTypesFn=*/[](mlir::Type) {});
}
/// Walk all of the held sub-types.
void walkSubTypes(llvm::function_ref<void(mlir::Type)> walkFn) {
walkSubElements(/*walkAttrsFn=*/[](mlir::Attribute) {}, walkFn);
}
/// Walk all of the held sub-attributes and sub-types.
void walkSubElements(llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
llvm::function_ref<void(mlir::Type)> walkTypesFn) {
}] # interfaceName # " interface(" # derivedValue # [{);
interface.walkSubElements(walkAttrsFn, walkTypesFn);
}
}];
}
//===----------------------------------------------------------------------===//
// SubElementAttrInterface
//===----------------------------------------------------------------------===//
def SubElementAttrInterface
: AttrInterface<"SubElementAttrInterface">,
SubElementInterfaceBase<"SubElementAttrInterface", "$_attr"> {
let description = [{
An interface used to query and manipulate sub-elements, such as sub-types
and sub-attributes of a composite attribute.
}];
}
//===----------------------------------------------------------------------===//
// SubElementTypeInterface
//===----------------------------------------------------------------------===//
def SubElementTypeInterface
: TypeInterface<"SubElementTypeInterface">,
SubElementInterfaceBase<"SubElementTypeInterface", "$_type"> {
let description = [{
An interface used to query and manipulate sub-elements, such as sub-types
and sub-attributes of a composite type.
}];
}
#endif // MLIR_IR_SUBELEMENTINTERFACES_TD_
......@@ -22,6 +22,7 @@
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SubElementInterfaces.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
......@@ -626,14 +627,10 @@ void AliasInitializer::visit(Attribute attr, bool canBeDeferred) {
return;
}
if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
for (Attribute element : arrayAttr.getValue())
visit(element);
} else if (auto dictAttr = attr.dyn_cast<DictionaryAttr>()) {
for (const NamedAttribute &attr : dictAttr)
visit(attr.second);
} else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
visit(typeAttr.getValue());
// Check for any sub elements.
if (auto subElementInterface = attr.dyn_cast<SubElementAttrInterface>()) {
subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
[&](Type type) { visit(type); });
}
}
......@@ -645,20 +642,10 @@ void AliasInitializer::visit(Type type) {
if (succeeded(generateAlias(type, aliasToType)))
return;
// Visit several subtypes that contain types or attributes.
if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
for (auto input : funcType.getInputs())
visit(input);
for (auto result : funcType.getResults())
visit(result);
} else if (auto shapedType = type.dyn_cast<ShapedType>()) {
visit(shapedType.getElementType());
// Visit affine maps in memref type.
if (auto memref = type.dyn_cast<MemRefType>())
for (auto map : memref.getAffineMaps())
visit(AffineMapAttr::get(map));
// Check for any sub elements.
if (auto subElementInterface = type.dyn_cast<SubElementTypeInterface>()) {
subElementInterface.walkSubElements([&](Attribute attr) { visit(attr); },
[&](Type type) { visit(type); });
}
}
......
......@@ -42,6 +42,17 @@ void BuiltinDialect::registerAttributes() {
UnitAttr>();
}
//===----------------------------------------------------------------------===//
// ArrayAttr
//===----------------------------------------------------------------------===//
void ArrayAttr::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
for (Attribute attr : getValue())
walkAttrsFn(attr);
}
//===----------------------------------------------------------------------===//
// DictionaryAttr
//===----------------------------------------------------------------------===//
......@@ -197,6 +208,13 @@ DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
return Base::get(context, ArrayRef<NamedAttribute>());
}
void DictionaryAttr::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
for (Attribute attr : llvm::make_second_range(getValue()))
walkAttrsFn(attr);
}
//===----------------------------------------------------------------------===//
// StringAttr
//===----------------------------------------------------------------------===//
......@@ -1370,3 +1388,13 @@ std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
{&*std::next(sparseIndexValues.begin(), i * rank), rank}));
return flatSparseIndices;
}
//===----------------------------------------------------------------------===//
// TypeAttr
//===----------------------------------------------------------------------===//
void TypeAttr::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getValue());
}
......@@ -199,6 +199,13 @@ FunctionType::getWithoutArgsAndResults(ArrayRef<unsigned> argIndices,
return get(getContext(), newInputTypes, newResultTypes);
}
void FunctionType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
for (Type type : llvm::concat<const Type>(getInputs(), getResults()))
walkTypesFn(type);
}
//===----------------------------------------------------------------------===//
// OpaqueType
//===----------------------------------------------------------------------===//
......@@ -419,6 +426,12 @@ VectorType VectorType::scaleElementBitwidth(unsigned scale) {
return VectorType();
}
void VectorType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getElementType());
}
//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//
......@@ -459,6 +472,12 @@ RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
return checkTensorElementType(emitError, elementType);
}
void RankedTensorType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getElementType());
}
//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//
......@@ -469,6 +488,12 @@ UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
return checkTensorElementType(emitError, elementType);
}
void UnrankedTensorType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getElementType());
}
//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//
......@@ -612,6 +637,15 @@ LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
return success();
}
void MemRefType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getElementType());
walkAttrsFn(getMemorySpace());
for (AffineMap map : getAffineMaps())
walkAttrsFn(AffineMapAttr::get(map));
}
//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//
......@@ -779,6 +813,13 @@ LogicalResult mlir::getStridesAndOffset(MemRefType t,
return success();
}
void UnrankedMemRefType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
walkTypesFn(getElementType());
walkAttrsFn(getMemorySpace());
}
//===----------------------------------------------------------------------===//
/// TupleType
//===----------------------------------------------------------------------===//
......@@ -802,6 +843,13 @@ void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }
void TupleType::walkImmediateSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) const {
for (Type type : getTypes())
walkTypesFn(type);
}
//===----------------------------------------------------------------------===//
// Type Utilities
//===----------------------------------------------------------------------===//
......
......@@ -21,6 +21,7 @@ add_mlir_library(MLIRIR
PatternMatch.cpp
Region.cpp
RegionKindInterface.cpp
SubElementInterfaces.cpp
SymbolTable.cpp
TensorEncoding.cpp
Types.cpp
......@@ -46,6 +47,7 @@ add_mlir_library(MLIRIR
MLIROpAsmInterfaceIncGen
MLIRRegionKindInterfaceIncGen
MLIRSideEffectInterfacesIncGen
MLIRSubElementInterfacesIncGen
MLIRSymbolInterfacesIncGen
MLIRTensorEncodingIncGen
......
//===- SubElementInterfaces.cpp - Attr and Type SubElement Interfaces -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/IR/SubElementInterfaces.h"
using namespace mlir;
template <typename InterfaceT>
static void walkSubElementsImpl(InterfaceT interface,
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
interface.walkImmediateSubElements(
[&](Attribute attr) {
// Guard against potentially null inputs. This removes the need for the
// derived attribute/type to do it.
if (!attr)
return;
// Walk any sub elements first.
if (auto interface = attr.dyn_cast<SubElementAttrInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
// Walk this attribute.
walkAttrsFn(attr);
},
[&](Type type) {
// Guard against potentially null inputs. This removes the need for the
// derived attribute/type to do it.
if (!type)
return;
// Walk any sub elements first.
if (auto interface = type.dyn_cast<SubElementTypeInterface>())
walkSubElementsImpl(interface, walkAttrsFn, walkTypesFn);
// Walk this type.
walkTypesFn(type);
});
}
void SubElementAttrInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
}
void SubElementTypeInterface::walkSubElements(
function_ref<void(Attribute)> walkAttrsFn,
function_ref<void(Type)> walkTypesFn) {
assert(walkAttrsFn && walkTypesFn && "expected valid walk functions");
walkSubElementsImpl(*this, walkAttrsFn, walkTypesFn);
}
//===----------------------------------------------------------------------===//
// SubElementInterface Tablegen definitions
//===----------------------------------------------------------------------===//
#include "mlir/IR/SubElementAttrInterfaces.cpp.inc"
#include "mlir/IR/SubElementTypeInterfaces.cpp.inc"
......@@ -4,6 +4,7 @@ add_mlir_unittest(MLIRIRTests
MemRefTypeTest.cpp
OperationSupportTest.cpp