diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td index 8f3670b29d748c52a8d1dda1e70c7c7a324ca672..cde53725b4a4f2dba068989c423427ee0c3d97c5 100644 --- a/flang/include/flang/Optimizer/Dialect/FIROps.td +++ b/flang/include/flang/Optimizer/Dialect/FIROps.td @@ -267,27 +267,27 @@ class fir_AllocatableOp traits = []> : static constexpr llvm::StringRef inType() { return "in_type"; } static constexpr llvm::StringRef lenpName() { return "len_param_count"; } mlir::Type getAllocatedType(); - + bool hasLenParams() { return bool{(*this)->getAttr(lenpName())}; } - + unsigned numLenParams() { if (auto val = (*this)->getAttrOfType(lenpName())) return val.getInt(); return 0; } - + operand_range getLenParams() { return {operand_begin(), operand_begin() + numLenParams()}; } - + unsigned numShapeOperands() { return operand_end() - operand_begin() + numLenParams(); } - + operand_range getShapeOperands() { return {operand_begin() + numLenParams(), operand_end()}; } - + static mlir::Type getRefTy(mlir::Type ty); /// Get the input type of the allocation @@ -1131,7 +1131,7 @@ def fir_EmboxCharOp : fir_Op<"emboxchar", [NoSideEffect]> { }]; let arguments = (ins AnyReferenceLike:$memref, AnyIntegerLike:$len); - + let results = (outs fir_BoxCharType); let assemblyFormat = [{ @@ -1563,7 +1563,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { p.printFunctionalType((*this)->getOperandTypes(), (*this)->getResultTypes()); }]; - + let verifier = [{ auto refTy = ref().getType(); if (fir::isa_ref_type(refTy)) { @@ -1598,7 +1598,7 @@ def fir_CoordinateOp : fir_Op<"coordinate_of", [NoSideEffect]> { CArg<"ArrayRef", "{}">:$attrs)>, OpBuilderDAG<(ins "Type":$type, "ValueRange":$operands, CArg<"ArrayRef", "{}">:$attrs)>]; - + let extraClassDeclaration = [{ static constexpr llvm::StringRef baseType() { return "base_type"; } mlir::Type getBaseType(); @@ -1686,7 +1686,7 @@ def fir_FieldIndexOp : fir_OneResultOp<"field_index", [NoSideEffect]> { let printer = [{ p << getOperationName() << ' ' - << (*this)->getAttrOfType(fieldAttrName()).getValue() + << (*this)->getAttrOfType(fieldAttrName()).getValue() << ", " << (*this)->getAttr(typeAttrName()); if (getNumOperands()) { p << '('; @@ -2007,7 +2007,7 @@ def fir_IterWhileOp : region_Op<"iterate_while", CArg<"ValueRange", "llvm::None">:$iterArgs, CArg<"ArrayRef", "{}">:$attributes)> ]; - + let extraClassDeclaration = [{ mlir::Block *getBody() { return ®ion().front(); } mlir::Value getIterateVar() { return getBody()->getArgument(1); } @@ -2276,11 +2276,11 @@ def fir_ConstfOp : fir_Op<"constf", [NoSideEffect]> { }]; let arguments = (ins FirRealAttr:$constant); - + let results = (outs fir_RealType:$res); let assemblyFormat = "`(` $constant `)` attr-dict `:` type($res)"; - + let verifier = [{ if (!getType().isa()) return emitOpError("must be a !fir.real type"); @@ -2357,7 +2357,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> { }]; let results = (outs fir_ComplexType); - + let parser = [{ fir::RealAttr realp; fir::RealAttr imagp; @@ -2455,7 +2455,7 @@ def fir_CmpcOp : fir_Op<"cmpc", def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> { let summary = "convert a symbol to an SSA value"; - + let description = [{ Convert a symbol (a function or global reference) to an SSA-value to be used in other Operations. @@ -2474,7 +2474,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> { def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> { let summary = "encapsulates all Fortran scalar type conversions"; - + let description = [{ Generalized type conversion. Convert the ssa value from type T to type U. Not all pairs of types have conversions. When types T and U are the same @@ -2705,7 +2705,7 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { mlir::Type resultType() { return fir::AllocaOp::wrapResultType(getType()); } - + /// Return the initializer attribute if it exists, or a null attribute. Attribute getValueOrNull() { return initVal().getValueOr(Attribute()); } @@ -2728,9 +2728,9 @@ def fir_GlobalOp : fir_Op<"global", [IsolatedFromAbove, Symbol]> { } mlir::FlatSymbolRefAttr getSymbol() { - return mlir::FlatSymbolRefAttr::get( + return mlir::FlatSymbolRefAttr::get(getContext(), (*this)->getAttrOfType( - mlir::SymbolTable::getSymbolAttrName()).getValue(), getContext()); + mlir::SymbolTable::getSymbolAttrName()).getValue()); } }]; } @@ -2772,7 +2772,7 @@ def fir_GlobalLenOp : fir_Op<"global_len", []> { }]; let printer = [{ - p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) + p << getOperationName() << ' ' << (*this)->getAttr(lenParamAttrName()) << ", " << (*this)->getAttr(intAttrName()); }]; diff --git a/flang/lib/Lower/FIRBuilder.cpp b/flang/lib/Lower/FIRBuilder.cpp index 3f470d61c28668bae805c0df88427253024d8558..0a8473b73268bb23fa56a7d89fea8ced84a8d622 100644 --- a/flang/lib/Lower/FIRBuilder.cpp +++ b/flang/lib/Lower/FIRBuilder.cpp @@ -173,7 +173,7 @@ mlir::Value Fortran::lower::FirOpBuilder::createConvert(mlir::Location loc, fir::StringLitOp Fortran::lower::FirOpBuilder::createStringLit( mlir::Location loc, mlir::Type eleTy, llvm::StringRef data) { - auto strAttr = mlir::StringAttr::get(data, getContext()); + auto strAttr = mlir::StringAttr::get(getContext(), data); auto valTag = mlir::Identifier::get(fir::StringLitOp::value(), getContext()); mlir::NamedAttribute dataAttr(valTag, strAttr); auto sizeTag = mlir::Identifier::get(fir::StringLitOp::size(), getContext()); diff --git a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp index 3883ce2ed0c8b5e4ad73115ee223c05b9ef62810..8523a83711927a010015f18385c860961a676ef5 100644 --- a/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp @@ -107,7 +107,7 @@ private: ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` @@ -120,7 +120,7 @@ private: PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); } /// Return a value representing an access into a global string with the given diff --git a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp index 3883ce2ed0c8b5e4ad73115ee223c05b9ef62810..8523a83711927a010015f18385c860961a676ef5 100644 --- a/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp @@ -107,7 +107,7 @@ private: ModuleOp module) { auto *context = module.getContext(); if (module.lookupSymbol("printf")) - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); // Create a function declaration for printf, the signature is: // * `i32 (i8*, ...)` @@ -120,7 +120,7 @@ private: PatternRewriter::InsertionGuard insertGuard(rewriter); rewriter.setInsertionPointToStart(module.getBody()); rewriter.create(module.getLoc(), "printf", llvmFnType); - return SymbolRefAttr::get("printf", context); + return SymbolRefAttr::get(context, "printf"); } /// Return a value representing an access into a global string with the given diff --git a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h index 794417e996521fdfc2a1cffa6e706ddf28731438..b903c0928d1b425f1023489d99dc615beb601b5e 100644 --- a/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h +++ b/mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h @@ -31,7 +31,7 @@ inline bool isRowMajorMatmul(ArrayAttr indexingMaps) { auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {m, n}, context)); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); return indexingMaps == maps; } @@ -42,7 +42,7 @@ inline bool isColumnMajorMatmul(ArrayAttr indexingMaps) { auto mapA = AffineMapAttr::get(AffineMap::get(3, 0, {k, n}, context)); auto mapB = AffineMapAttr::get(AffineMap::get(3, 0, {m, k}, context)); auto mapC = AffineMapAttr::get(AffineMap::get(3, 0, {n, m}, context)); - auto maps = ArrayAttr::get({mapA, mapB, mapC}, context); + auto maps = ArrayAttr::get(context, {mapA, mapB, mapC}); return indexingMaps == maps; } diff --git a/mlir/include/mlir/IR/BuiltinAttributes.h b/mlir/include/mlir/IR/BuiltinAttributes.h index 34e7e8cfce12cb6fefb042eab3c762cf9eedab84..571c9126f1631c05a314c40f427e9fa45d711a40 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.h +++ b/mlir/include/mlir/IR/BuiltinAttributes.h @@ -69,7 +69,7 @@ public: using Base::Base; using ValueType = ArrayRef; - static ArrayAttr get(ArrayRef value, MLIRContext *context); + static ArrayAttr get(MLIRContext *context, ArrayRef value); ArrayRef getValue() const; Attribute operator[](unsigned idx) const; @@ -126,8 +126,8 @@ public: /// attributes. This method assumes that the provided list is unordered. If /// the caller can guarantee that the attributes are ordered by name, /// getWithSorted should be used instead. - static DictionaryAttr get(ArrayRef value, - MLIRContext *context); + static DictionaryAttr get(MLIRContext *context, + ArrayRef value); /// Construct a dictionary with an array of values that is known to already be /// sorted by name and uniqued. @@ -250,7 +250,7 @@ public: using Attribute::Attribute; using ValueType = bool; - static BoolAttr get(bool value, MLIRContext *context); + static BoolAttr get(MLIRContext *context, bool value); /// Enable conversion to IntegerAttr. This uses conversion vs. inheritance to /// avoid bringing in all of IntegerAttrs methods. @@ -292,8 +292,8 @@ public: using Base::Base; /// Get or create a new OpaqueAttr with the provided dialect and string data. - static OpaqueAttr get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context); + static OpaqueAttr get(MLIRContext *context, Identifier dialect, + StringRef attrData, Type type); /// Get or create a new OpaqueAttr with the provided dialect and string data. /// If the given identifier is not a valid namespace for a dialect, then a @@ -325,7 +325,7 @@ public: using ValueType = StringRef; /// Get an instance of a StringAttr with the given string. - static StringAttr get(StringRef bytes, MLIRContext *context); + static StringAttr get(MLIRContext *context, StringRef bytes); /// Get an instance of a StringAttr with the given string and Type. static StringAttr get(StringRef bytes, Type type); @@ -348,13 +348,12 @@ public: using Base::Base; /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value); /// Construct a symbol reference for the given value name, and a set of nested /// references that are further resolve to a nested symbol. - static SymbolRefAttr get(StringRef value, - ArrayRef references, - MLIRContext *ctx); + static SymbolRefAttr get(MLIRContext *ctx, StringRef value, + ArrayRef references); /// Returns the name of the top level symbol reference, i.e. the root of the /// reference path. @@ -377,8 +376,8 @@ public: using ValueType = StringRef; /// Construct a symbol reference for the given value name. - static FlatSymbolRefAttr get(StringRef value, MLIRContext *ctx) { - return SymbolRefAttr::get(value, ctx); + static FlatSymbolRefAttr get(MLIRContext *ctx, StringRef value) { + return SymbolRefAttr::get(ctx, value); } /// Returns the name of the held symbol reference. diff --git a/mlir/include/mlir/IR/FunctionSupport.h b/mlir/include/mlir/IR/FunctionSupport.h index be8a68979203073c9a8af7864b529981642271d4..588a5f7ed62c4f8051984b99a54dca9dcfd99698 100644 --- a/mlir/include/mlir/IR/FunctionSupport.h +++ b/mlir/include/mlir/IR/FunctionSupport.h @@ -569,7 +569,7 @@ void FunctionLike::setArgAttrs( if (attributes.empty()) return (void)static_cast(this)->removeAttr(nameOut); Operation *op = this->getOperation(); - op->setAttr(nameOut, DictionaryAttr::get(attributes, op->getContext())); + op->setAttr(nameOut, DictionaryAttr::get(op->getContext(), attributes)); } template diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 45b9c490fd212904395fe32a8c93cf2c5df18608..70cd55dbbb138895af8d8a14c5ebdf7c87174d1b 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -315,7 +315,7 @@ public: attrs = newAttrs; } void setAttrs(ArrayRef newAttrs) { - setAttrs(DictionaryAttr::get(newAttrs, getContext())); + setAttrs(DictionaryAttr::get(getContext(), newAttrs)); } /// Return the specified attribute if present, null otherwise. diff --git a/mlir/include/mlir/IR/SymbolInterfaces.td b/mlir/include/mlir/IR/SymbolInterfaces.td index c5f252e45a2018b4f79fabff5a11c69bc76003fe..a7b1fd8cfe6433529f062266f66565d0804b129b 100644 --- a/mlir/include/mlir/IR/SymbolInterfaces.td +++ b/mlir/include/mlir/IR/SymbolInterfaces.td @@ -44,7 +44,7 @@ def Symbol : OpInterface<"SymbolOpInterface"> { /*defaultImplementation=*/[{ this->getOperation()->setAttr( mlir::SymbolTable::getSymbolAttrName(), - StringAttr::get(name, this->getOperation()->getContext())); + StringAttr::get(this->getOperation()->getContext(), name)); }] >, InterfaceMethod<"Gets the visibility of this symbol.", diff --git a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp index 90ed9cb0ad02d1be62842f48c7b5299380a333cd..9e61e3a9d6e05ac9974381a0eb56d72c20731298 100644 --- a/mlir/lib/CAPI/IR/BuiltinAttributes.cpp +++ b/mlir/lib/CAPI/IR/BuiltinAttributes.cpp @@ -42,9 +42,9 @@ bool mlirAttributeIsAArray(MlirAttribute attr) { MlirAttribute mlirArrayAttrGet(MlirContext ctx, intptr_t numElements, MlirAttribute const *elements) { SmallVector attrs; - return wrap(ArrayAttr::get( - unwrapList(static_cast(numElements), elements, attrs), - unwrap(ctx))); + return wrap( + ArrayAttr::get(unwrap(ctx), unwrapList(static_cast(numElements), + elements, attrs))); } intptr_t mlirArrayAttrGetNumElements(MlirAttribute attr) { @@ -71,7 +71,7 @@ MlirAttribute mlirDictionaryAttrGet(MlirContext ctx, intptr_t numElements, attributes.emplace_back( Identifier::get(unwrap(elements[i].name), unwrap(ctx)), unwrap(elements[i].attribute)); - return wrap(DictionaryAttr::get(attributes, unwrap(ctx))); + return wrap(DictionaryAttr::get(unwrap(ctx), attributes)); } intptr_t mlirDictionaryAttrGetNumElements(MlirAttribute attr) { @@ -137,7 +137,7 @@ bool mlirAttributeIsABool(MlirAttribute attr) { } MlirAttribute mlirBoolAttrGet(MlirContext ctx, int value) { - return wrap(BoolAttr::get(value, unwrap(ctx))); + return wrap(BoolAttr::get(unwrap(ctx), value)); } bool mlirBoolAttrGetValue(MlirAttribute attr) { @@ -163,9 +163,9 @@ bool mlirAttributeIsAOpaque(MlirAttribute attr) { MlirAttribute mlirOpaqueAttrGet(MlirContext ctx, MlirStringRef dialectNamespace, intptr_t dataLength, const char *data, MlirType type) { - return wrap( - OpaqueAttr::get(Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), - StringRef(data, dataLength), unwrap(type), unwrap(ctx))); + return wrap(OpaqueAttr::get( + unwrap(ctx), Identifier::get(unwrap(dialectNamespace), unwrap(ctx)), + StringRef(data, dataLength), unwrap(type))); } MlirStringRef mlirOpaqueAttrGetDialectNamespace(MlirAttribute attr) { @@ -185,7 +185,7 @@ bool mlirAttributeIsAString(MlirAttribute attr) { } MlirAttribute mlirStringAttrGet(MlirContext ctx, MlirStringRef str) { - return wrap(StringAttr::get(unwrap(str), unwrap(ctx))); + return wrap(StringAttr::get(unwrap(ctx), unwrap(str))); } MlirAttribute mlirStringAttrTypedGet(MlirType type, MlirStringRef str) { @@ -211,7 +211,7 @@ MlirAttribute mlirSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol, refs.reserve(numReferences); for (intptr_t i = 0; i < numReferences; ++i) refs.push_back(unwrap(references[i]).cast()); - return wrap(SymbolRefAttr::get(unwrap(symbol), refs, unwrap(ctx))); + return wrap(SymbolRefAttr::get(unwrap(ctx), unwrap(symbol), refs)); } MlirStringRef mlirSymbolRefAttrGetRootReference(MlirAttribute attr) { @@ -241,7 +241,7 @@ bool mlirAttributeIsAFlatSymbolRef(MlirAttribute attr) { } MlirAttribute mlirFlatSymbolRefAttrGet(MlirContext ctx, MlirStringRef symbol) { - return wrap(FlatSymbolRefAttr::get(unwrap(symbol), unwrap(ctx))); + return wrap(FlatSymbolRefAttr::get(unwrap(ctx), unwrap(symbol))); } MlirStringRef mlirFlatSymbolRefAttrGetValue(MlirAttribute attr) { diff --git a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp index 447b00567776d5199ab554a76a6768d8564fbc85..1b9e36180114604cf2e30de021027857f3aba3e2 100644 --- a/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp +++ b/mlir/lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp @@ -148,7 +148,7 @@ StringAttr GpuKernelToBlobPass::translateGPUModuleToBinaryAnnotation( auto blob = convertModuleToBlob(llvmModule, loc, name); if (!blob) return {}; - return StringAttr::get({blob->data(), blob->size()}, loc->getContext()); + return StringAttr::get(loc->getContext(), {blob->data(), blob->size()}); } std::unique_ptr> diff --git a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp index 887d3e798af74629dc56f50c73e7e4c91ea0822b..5b62ca455dea305b67c8399759d60856a40e6f1b 100644 --- a/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp +++ b/mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp @@ -177,12 +177,12 @@ void ConvertGpuLaunchFuncToVulkanLaunchFunc::convertGpuLaunchFunc( // Set SPIR-V binary shader data as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVBlobAttrName, - StringAttr::get({binary.data(), binary.size()}, loc->getContext())); + StringAttr::get(loc->getContext(), {binary.data(), binary.size()})); // Set entry point name as an attribute. vulkanLaunchCallOp->setAttr( kSPIRVEntryPointAttrName, - StringAttr::get(launchOp.getKernelName(), loc->getContext())); + StringAttr::get(loc->getContext(), launchOp.getKernelName())); launchOp.erase(); } diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp index 87026e4483e6eaf35fb3e91e59e57b33fea98c83..29cf42205a563408b8ea18b8a66d9fa29dd5d4c8 100644 --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -687,8 +687,8 @@ public: rewriter.create(loc, llvmI32Type, executionModeAttr); structValue = rewriter.create( loc, structType, structValue, executionMode, - ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 0)}, - context)); + ArrayAttr::get(context, + {rewriter.getIntegerAttr(rewriter.getI32Type(), 0)})); // Insert extra operands if they exist into execution mode info struct. for (unsigned i = 0, e = values.size(); i < e; ++i) { @@ -696,9 +696,9 @@ public: Value entry = rewriter.create(loc, llvmI32Type, attr); structValue = rewriter.create( loc, structType, structValue, entry, - ArrayAttr::get({rewriter.getIntegerAttr(rewriter.getI32Type(), 1), - rewriter.getIntegerAttr(rewriter.getI32Type(), i)}, - context)); + ArrayAttr::get(context, + {rewriter.getIntegerAttr(rewriter.getI32Type(), 1), + rewriter.getIntegerAttr(rewriter.getI32Type(), i)})); } rewriter.create(loc, ArrayRef({structValue})); rewriter.eraseOp(op); @@ -1297,17 +1297,17 @@ public: switch (funcOp.function_control()) { #define DISPATCH(functionControl, llvmAttr) \ case functionControl: \ - newFuncOp->setAttr("passthrough", ArrayAttr::get({llvmAttr}, context)); \ + newFuncOp->setAttr("passthrough", ArrayAttr::get(context, {llvmAttr})); \ break; DISPATCH(spirv::FunctionControl::Inline, - StringAttr::get("alwaysinline", context)); + StringAttr::get(context, "alwaysinline")); DISPATCH(spirv::FunctionControl::DontInline, - StringAttr::get("noinline", context)); + StringAttr::get(context, "noinline")); DISPATCH(spirv::FunctionControl::Pure, - StringAttr::get("readonly", context)); + StringAttr::get(context, "readonly")); DISPATCH(spirv::FunctionControl::Const, - StringAttr::get("readnone", context)); + StringAttr::get(context, "readnone")); #undef DISPATCH diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp index 794f4a5d6c1e1126a1fb103b12386b9ce2ea5648..ea0a4259637cb2036818d74a7f480ac482d863dd 100644 --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -4016,7 +4016,7 @@ struct LLVMLoweringPass : public ConvertStandardToLLVMBase { if (failed(applyPartialConversion(m, target, std::move(patterns)))) signalPassFailure(); m->setAttr(LLVM::LLVMDialect::getDataLayoutAttrName(), - StringAttr::get(this->dataLayout, m.getContext())); + StringAttr::get(m.getContext(), this->dataLayout)); } }; } // end namespace diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 9e88250e2cab14f921f8e219191c1384530ce2ef..683de815a54e468858ee694fce32fc481a1704f9 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -762,7 +762,7 @@ public: if (positionAttrs.size() > 1) { auto oneDVectorType = reducedVectorTypeBack(vectorType); auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); extracted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); @@ -871,7 +871,7 @@ public: if (positionAttrs.size() > 1) { oneDVectorType = reducedVectorTypeBack(destVectorType); auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); extracted = rewriter.create( loc, typeConverter->convertType(oneDVectorType), extracted, nMinusOnePositionAttrs); @@ -887,7 +887,7 @@ public: // Potential insertion of resulting 1-D vector into array. if (positionAttrs.size() > 1) { auto nMinusOnePositionAttrs = - ArrayAttr::get(positionAttrs.drop_back(), context); + ArrayAttr::get(context, positionAttrs.drop_back()); inserted = rewriter.create(loc, llvmResultType, adaptor.dest(), inserted, nMinusOnePositionAttrs); diff --git a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp index c1d0820e1cc72524d7061ae28da09a9765b564ab..6ccb59aff35a3f11fe9d79149f8d948416491434 100644 --- a/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/ParallelLoopMapper.cpp @@ -53,7 +53,7 @@ LogicalResult setMappingAttr(scf::ParallelOp ploopOp, } ArrayRef mappingAsAttrs(mapping.data(), mapping.size()); ploopOp->setAttr(getMappingAttrName(), - ArrayAttr::get(mappingAsAttrs, ploopOp.getContext())); + ArrayAttr::get(ploopOp.getContext(), mappingAsAttrs)); return success(); } } // namespace gpu diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index a3960ae94b2706e746d2162cad4149c746d8648a..e96668779401865ef4d398c715af0bfeb91c92e7 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -225,7 +225,7 @@ static void printGenericOp(OpAsmPrinter &p, GenericOpType op) { if (genericAttrNamesSet.count(attr.first.strref()) > 0) genericAttrs.push_back(attr); if (!genericAttrs.empty()) { - auto genericDictAttr = DictionaryAttr::get(genericAttrs, op.getContext()); + auto genericDictAttr = DictionaryAttr::get(op.getContext(), genericAttrs); p << genericDictAttr; } @@ -833,7 +833,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef mapsProducer, // Handle the corner case of the result being a rank 0 shaped type. Return an // emtpy ArrayAttr. if (mapsConsumer.empty() && !mapsProducer.empty()) - return ArrayAttr::get(ArrayRef(), context); + return ArrayAttr::get(context, ArrayRef()); if (mapsProducer.empty() || mapsConsumer.empty() || mapsProducer[0].getNumDims() < mapsConsumer[0].getNumDims() || mapsProducer.size() != mapsConsumer[0].getNumDims()) @@ -854,7 +854,7 @@ static ArrayAttr collapseReassociationMaps(ArrayRef mapsProducer, numLhsDims, /*numSymbols =*/0, reassociations, context))); reassociations.clear(); } - return ArrayAttr::get(reassociationMaps, context); + return ArrayAttr::get(context, reassociationMaps); } namespace { diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp index 8db4824cbbd2cc5dc4287c9ebb820fa795255025..c7b76404b2f88ecaa84b1ccf47d055d2ddd70ade 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp @@ -137,11 +137,11 @@ static ArrayAttr replaceUnitDims(DenseSet &unitDims, // wrong, so abort. if (!inversePermutation(concatAffineMaps(newIndexingMaps))) return nullptr; - return ArrayAttr::get( - llvm::to_vector<4>(llvm::map_range( - newIndexingMaps, - [](AffineMap map) -> Attribute { return AffineMapAttr::get(map); })), - context); + return ArrayAttr::get(context, + llvm::to_vector<4>(llvm::map_range( + newIndexingMaps, [](AffineMap map) -> Attribute { + return AffineMapAttr::get(map); + }))); } /// Modify the region of indexed generic op to drop arguments corresponding to @@ -220,7 +220,7 @@ struct FoldUnitDimLoops : public OpRewritePattern { rewriter.startRootUpdate(op); op.indexing_mapsAttr(newIndexingMapAttr); - op.iterator_typesAttr(ArrayAttr::get(newIteratorTypes, context)); + op.iterator_typesAttr(ArrayAttr::get(context, newIteratorTypes)); (void)replaceBlockArgForUnitDimLoops(op, unitDims, rewriter); rewriter.finalizeRootUpdate(op); return success(); @@ -282,7 +282,7 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap, RankedTensorType::get(newShape, type.getElementType()), AffineMap::get(indexMap.getNumDims(), indexMap.getNumSymbols(), newIndexExprs, context), - ArrayAttr::get(reassociationMaps, context)}; + ArrayAttr::get(context, reassociationMaps)}; return info; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp index cac0ae0d081c44602432fa52103bbb2fd1743213..b893f2ba672115c2d2352d3ab54143edf5ff5505 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Interchange.cpp @@ -77,9 +77,9 @@ LinalgOp mlir::linalg::interchange(LinalgOp op, applyPermutationToVector(itTypesVector, interchangeVector); op->setAttr(getIndexingMapsAttrName(), - ArrayAttr::get(newIndexingMaps, context)); + ArrayAttr::get(context, newIndexingMaps)); op->setAttr(getIteratorTypesAttrName(), - ArrayAttr::get(itTypesVector, context)); + ArrayAttr::get(context, itTypesVector)); return op; } diff --git a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp index 9b62b4289c7732ee753f5301ad96f8e02eca73c9..4ce29b4a83972c4f55406033c7fb166cab3e478e 100644 --- a/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/LowerABIAttributesPass.cpp @@ -98,7 +98,7 @@ getInterfaceVariables(spirv::FuncOp funcOp, }); for (auto &var : interfaceVarSet) { interfaceVars.push_back(SymbolRefAttr::get( - cast(var).sym_name(), funcOp.getContext())); + funcOp.getContext(), cast(var).sym_name())); } return success(); } diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 0902b297ddd35f9117c311a48070333554ac688f..65ebc54aeeb367051a945c01bf1b2c534e553063 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -338,7 +338,7 @@ OpFoldResult AssumingAllOp::fold(ArrayRef operands) { return a; } // If this is reached, all inputs were statically known passing. - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); } static LogicalResult verify(AssumingAllOp op) { @@ -482,10 +482,10 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { // Both operands are not needed if one is a scalar. if (operands[0] && operands[0].cast().getNumElements() == 0) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); if (operands[1] && operands[1].cast().getNumElements() == 0) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); if (operands[0] && operands[1]) { auto lhsShape = llvm::to_vector<6>( @@ -494,7 +494,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { operands[1].cast().getValues()); SmallVector resultShape; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); } // Lastly, see if folding can be completed based on what constraints are known @@ -506,7 +506,7 @@ OpFoldResult CstrBroadcastableOp::fold(ArrayRef operands) { return nullptr; if (OpTrait::util::staticallyKnownBroadcastable(lhsShape, rhsShape)) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); // Because a failing witness result here represents an eventual assertion // failure, we do not replace it with a constant witness. @@ -526,7 +526,7 @@ void CstrEqOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, OpFoldResult CstrEqOp::fold(ArrayRef operands) { if (llvm::all_of(operands, [&](Attribute a) { return a && a == operands[0]; })) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); // Because a failing witness result here represents an eventual assertion // failure, we do not try to replace it with a constant witness. Similarly, we @@ -573,14 +573,14 @@ OpFoldResult CstrRequireOp::fold(ArrayRef operands) { OpFoldResult ShapeEqOp::fold(ArrayRef operands) { if (lhs() == rhs()) - return BoolAttr::get(true, getContext()); + return BoolAttr::get(getContext(), true); auto lhs = operands[0].dyn_cast_or_null(); if (lhs == nullptr) return {}; auto rhs = operands[1].dyn_cast_or_null(); if (rhs == nullptr) return {}; - return BoolAttr::get(lhs == rhs, getContext()); + return BoolAttr::get(getContext(), lhs == rhs); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index c085c1cd33a719df67736396a96ad0f6ce1abd2b..ca2e2731df0398161763a429e5a8133387c8fea7 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -844,7 +844,7 @@ OpFoldResult CmpIOp::fold(ArrayRef operands) { if (lhs() == rhs()) { auto val = applyCmpPredicateToEqualOperands(getPredicate()); - return BoolAttr::get(val, getContext()); + return BoolAttr::get(getContext(), val); } auto lhs = operands.front().dyn_cast_or_null(); @@ -853,7 +853,7 @@ OpFoldResult CmpIOp::fold(ArrayRef operands) { return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); - return BoolAttr::get(val, getContext()); + return BoolAttr::get(getContext(), val); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/VectorOps.cpp b/mlir/lib/Dialect/Vector/VectorOps.cpp index f20b713e8e77adc2eca0562f1c5334bbe2d0d838..9fe8cf23c162089e3340a3ea320a07411a00c17a 100644 --- a/mlir/lib/Dialect/Vector/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/VectorOps.cpp @@ -247,7 +247,7 @@ static void print(OpAsmPrinter &p, ContractionOp op) { if (traitAttrsSet.count(attr.first.strref()) > 0) attrs.push_back(attr); - auto dictAttr = DictionaryAttr::get(attrs, op.getContext()); + auto dictAttr = DictionaryAttr::get(op.getContext(), attrs); p << op.getOperationName() << " " << dictAttr << " " << op.lhs() << ", "; p << op.rhs() << ", " << op.acc(); if (op.masks().size() == 2) @@ -1445,7 +1445,7 @@ static ArrayAttr makeI64ArrayAttr(ArrayRef values, auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); }); - return ArrayAttr::get(llvm::to_vector<8>(attrs), context); + return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); } static LogicalResult verify(InsertStridedSliceOp op) { diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 8a5206eb0b1c8c9d75bdf5dbabbfd63a28cf21ed..bafeccbd53ea8c98d788d12fd7595d7816d63e68 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -92,11 +92,11 @@ NamedAttribute Builder::getNamedAttr(StringRef name, Attribute val) { UnitAttr Builder::getUnitAttr() { return UnitAttr::get(context); } BoolAttr Builder::getBoolAttr(bool value) { - return BoolAttr::get(value, context); + return BoolAttr::get(context, value); } DictionaryAttr Builder::getDictionaryAttr(ArrayRef value) { - return DictionaryAttr::get(value, context); + return DictionaryAttr::get(context, value); } IntegerAttr Builder::getIndexAttr(int64_t value) { @@ -200,11 +200,11 @@ FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) { } StringAttr Builder::getStringAttr(StringRef bytes) { - return StringAttr::get(bytes, context); + return StringAttr::get(context, bytes); } ArrayAttr Builder::getArrayAttr(ArrayRef value) { - return ArrayAttr::get(value, context); + return ArrayAttr::get(context, value); } FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { @@ -214,12 +214,12 @@ FlatSymbolRefAttr Builder::getSymbolRefAttr(Operation *value) { return getSymbolRefAttr(symName.getValue()); } FlatSymbolRefAttr Builder::getSymbolRefAttr(StringRef value) { - return SymbolRefAttr::get(value, getContext()); + return SymbolRefAttr::get(getContext(), value); } SymbolRefAttr Builder::getSymbolRefAttr(StringRef value, ArrayRef nestedReferences) { - return SymbolRefAttr::get(value, nestedReferences, getContext()); + return SymbolRefAttr::get(getContext(), value, nestedReferences); } ArrayAttr Builder::getBoolArrayAttr(ArrayRef values) { diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index 162bed96e3f4f2aa7e333848be6aa2be0c482dc6..58a5b3370364fd4f63f10a150e8b505b876f504a 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -35,7 +35,7 @@ AffineMap AffineMapAttr::getValue() const { return getImpl()->value; } // ArrayAttr //===----------------------------------------------------------------------===// -ArrayAttr ArrayAttr::get(ArrayRef value, MLIRContext *context) { +ArrayAttr ArrayAttr::get(MLIRContext *context, ArrayRef value) { return Base::get(context, value); } @@ -134,8 +134,8 @@ DictionaryAttr::findDuplicate(SmallVectorImpl &array, return findDuplicateElement(array); } -DictionaryAttr DictionaryAttr::get(ArrayRef value, - MLIRContext *context) { +DictionaryAttr DictionaryAttr::get(MLIRContext *context, + ArrayRef value) { if (value.empty()) return DictionaryAttr::getEmpty(context); assert(llvm::all_of(value, @@ -267,13 +267,12 @@ LogicalResult FloatAttr::verifyConstructionInvariants(Location loc, Type type, // SymbolRefAttr //===----------------------------------------------------------------------===// -FlatSymbolRefAttr SymbolRefAttr::get(StringRef value, MLIRContext *ctx) { +FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { return Base::get(ctx, value, llvm::None).cast(); } -SymbolRefAttr SymbolRefAttr::get(StringRef value, - ArrayRef nestedReferences, - MLIRContext *ctx) { +SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, + ArrayRef nestedReferences) { return Base::get(ctx, value, nestedReferences); } @@ -294,7 +293,7 @@ ArrayRef SymbolRefAttr::getNestedReferences() const { IntegerAttr IntegerAttr::get(Type type, const APInt &value) { if (type.isSignlessInteger(1)) - return BoolAttr::get(value.getBoolValue(), type.getContext()); + return BoolAttr::get(type.getContext(), value.getBoolValue()); return Base::get(type.getContext(), type, value); } @@ -377,8 +376,8 @@ IntegerSet IntegerSetAttr::getValue() const { return getImpl()->value; } // OpaqueAttr //===----------------------------------------------------------------------===// -OpaqueAttr OpaqueAttr::get(Identifier dialect, StringRef attrData, Type type, - MLIRContext *context) { +OpaqueAttr OpaqueAttr::get(MLIRContext *context, Identifier dialect, + StringRef attrData, Type type) { return Base::get(context, dialect, attrData, type); } @@ -409,7 +408,7 @@ LogicalResult OpaqueAttr::verifyConstructionInvariants(Location loc, // StringAttr //===----------------------------------------------------------------------===// -StringAttr StringAttr::get(StringRef bytes, MLIRContext *context) { +StringAttr StringAttr::get(MLIRContext *context, StringRef bytes) { return get(bytes, NoneType::get(context)); } diff --git a/mlir/lib/IR/BuiltinDialect.cpp b/mlir/lib/IR/BuiltinDialect.cpp index 469aa310140c75fb15e02f388ad46946df686463..db383c691c7cdc5fdde8c0ad849d59600fe6dbc4 100644 --- a/mlir/lib/IR/BuiltinDialect.cpp +++ b/mlir/lib/IR/BuiltinDialect.cpp @@ -166,7 +166,7 @@ void FuncOp::cloneInto(FuncOp dest, BlockAndValueMapping &mapper) { newAttrs.insert(attr); for (auto &attr : getAttrs()) newAttrs.insert(attr); - dest->setAttrs(DictionaryAttr::get(newAttrs.takeVector(), getContext())); + dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs.takeVector())); // Clone the body. getBody().cloneInto(&dest.getBody(), mapper); diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index dbfa1bdf6f7e41835496ad7df97e95b198298579..8d13a9c4af329f3139e17afc0dc5a366ff186a03 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -872,7 +872,7 @@ void AttributeUniquer::initializeAttributeStorage(AttributeStorage *storage, storage->setType(NoneType::get(ctx)); } -BoolAttr BoolAttr::get(bool value, MLIRContext *context) { +BoolAttr BoolAttr::get(MLIRContext *context, bool value) { return value ? context->getImpl().trueAttr : context->getImpl().falseAttr; } diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index b4fe9f854dda792b849f28c2d2f74daa98d6913c..be312689cebb57c182472b4af47f4d53370d0857 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -76,7 +76,7 @@ Operation *Operation::create(Location location, OperationName name, ArrayRef attributes, BlockRange successors, unsigned numRegions) { return create(location, name, resultTypes, operands, - DictionaryAttr::get(attributes, location.getContext()), + DictionaryAttr::get(location.getContext(), attributes), successors, numRegions); } diff --git a/mlir/lib/IR/SymbolTable.cpp b/mlir/lib/IR/SymbolTable.cpp index b198600e92422a07d4249c091118c76367049032..70133d22482ffd51f8f4310e419442f3c9136a89 100644 --- a/mlir/lib/IR/SymbolTable.cpp +++ b/mlir/lib/IR/SymbolTable.cpp @@ -46,7 +46,7 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName, assert(within->isAncestor(symbol) && "expected 'within' to be an ancestor"); MLIRContext *ctx = symbol->getContext(); - auto leafRef = FlatSymbolRefAttr::get(symbolName, ctx); + auto leafRef = FlatSymbolRefAttr::get(ctx, symbolName); results.push_back(leafRef); // Early exit for when 'within' is the parent of 'symbol'. @@ -67,13 +67,13 @@ collectValidReferencesFor(Operation *symbol, StringRef symbolName, getNameIfSymbol(symbolTableOp, symbolNameId); if (!symbolTableName) return failure(); - results.push_back(SymbolRefAttr::get(*symbolTableName, nestedRefs, ctx)); + results.push_back(SymbolRefAttr::get(ctx, *symbolTableName, nestedRefs)); symbolTableOp = symbolTableOp->getParentOp(); if (symbolTableOp == within) break; nestedRefs.insert(nestedRefs.begin(), - FlatSymbolRefAttr::get(*symbolTableName, ctx)); + FlatSymbolRefAttr::get(ctx, *symbolTableName)); } while (true); return success(); } @@ -203,7 +203,7 @@ StringRef SymbolTable::getSymbolName(Operation *symbol) { /// Sets the name of the given symbol operation. void SymbolTable::setSymbolName(Operation *symbol, StringRef name) { symbol->setAttr(getSymbolAttrName(), - StringAttr::get(name, symbol->getContext())); + StringAttr::get(symbol->getContext(), name)); } /// Returns the visibility of the given symbol operation. @@ -235,7 +235,7 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) { "unknown symbol visibility kind"); StringRef visName = vis == Visibility::Private ? "private" : "nested"; - symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx)); + symbol->setAttr(getVisibilityAttrName(), StringAttr::get(ctx, visName)); } /// Returns the nearest symbol table from a given operation `from`. Returns @@ -603,7 +603,7 @@ static SmallVector collectSymbolScopes(Operation *symbol, // doesn't support parent references. if (SymbolTable::getNearestSymbolTable(limit->getParentOp()) == symbol->getParentOp()) - return {{SymbolRefAttr::get(symName, symbol->getContext()), limit}}; + return {{SymbolRefAttr::get(symbol->getContext(), symName), limit}}; return {}; } @@ -659,7 +659,7 @@ static SmallVector collectSymbolScopes(Operation *symbol, template static SmallVector collectSymbolScopes(StringRef symbol, IRUnit *limit) { - return {{SymbolRefAttr::get(symbol, limit->getContext()), limit}}; + return {{SymbolRefAttr::get(limit->getContext(), symbol), limit}}; } /// Returns true if the given reference 'SubRef' is a sub reference of the @@ -825,11 +825,11 @@ static Attribute rebuildAttrAfterRAUW( if (auto dictAttr = container.dyn_cast()) { auto newAttrs = llvm::to_vector<4>(dictAttr.getValue()); updateAttrs(make_second_range(newAttrs)); - return DictionaryAttr::get(newAttrs, dictAttr.getContext()); + return DictionaryAttr::get(dictAttr.getContext(), newAttrs); } auto newAttrs = llvm::to_vector<4>(container.cast().getValue()); updateAttrs(newAttrs); - return ArrayAttr::get(newAttrs, container.getContext()); + return ArrayAttr::get(container.getContext(), newAttrs); } /// Generates a new symbol reference attribute with a new leaf reference. @@ -839,8 +839,8 @@ static SymbolRefAttr generateNewRefAttr(SymbolRefAttr oldAttr, return newLeafAttr; auto nestedRefs = llvm::to_vector<2>(oldAttr.getNestedReferences()); nestedRefs.back() = newLeafAttr; - return SymbolRefAttr::get(oldAttr.getRootReference(), nestedRefs, - oldAttr.getContext()); + return SymbolRefAttr::get(oldAttr.getContext(), oldAttr.getRootReference(), + nestedRefs); } /// The implementation of SymbolTable::replaceAllSymbolUses below. @@ -867,7 +867,7 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) { // Generate a new attribute to replace the given attribute. MLIRContext *ctx = limit->getContext(); - FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(newSymbol, ctx); + FlatSymbolRefAttr newLeafAttr = FlatSymbolRefAttr::get(ctx, newSymbol); for (SymbolScope &scope : collectSymbolScopes(symbol, limit)) { SymbolRefAttr newAttr = generateNewRefAttr(scope.symbol, newLeafAttr); auto walkFn = [&](SymbolTable::SymbolUse symbolUse, @@ -883,13 +883,13 @@ replaceAllSymbolUsesImpl(SymbolT symbol, StringRef newSymbol, IRUnitT *limit) { if (useRef != scope.symbol) { if (scope.symbol.isa()) { replacementRef = - SymbolRefAttr::get(newSymbol, useRef.getNestedReferences(), ctx); + SymbolRefAttr::get(ctx, newSymbol, useRef.getNestedReferences()); } else { auto nestedRefs = llvm::to_vector<4>(useRef.getNestedReferences()); nestedRefs[scope.symbol.getNestedReferences().size() - 1] = newLeafAttr; replacementRef = - SymbolRefAttr::get(useRef.getRootReference(), nestedRefs, ctx); + SymbolRefAttr::get(ctx, useRef.getRootReference(), nestedRefs); } } diff --git a/mlir/lib/Parser/AttributeParser.cpp b/mlir/lib/Parser/AttributeParser.cpp index 859e8e279917a9cd17a556506a42b932a64d0f70..98f74174e5a3ef81656917519d84c3e34d536066 100644 --- a/mlir/lib/Parser/AttributeParser.cpp +++ b/mlir/lib/Parser/AttributeParser.cpp @@ -148,7 +148,7 @@ Attribute Parser::parseAttribute(Type type) { return Attribute(); return type ? StringAttr::get(val, type) - : StringAttr::get(val, getContext()); + : StringAttr::get(getContext(), val); } // Parse a symbol reference attribute. @@ -176,7 +176,7 @@ Attribute Parser::parseAttribute(Type type) { std::string nameStr = getToken().getSymbolReference(); consumeToken(Token::at_identifier); - nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext())); + nestedRefs.push_back(SymbolRefAttr::get(getContext(), nameStr)); } return builder.getSymbolRefAttr(nameStr, nestedRefs); diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 2f0b3379d152bdabe5b71e9b4fcf1970b9d56251..52ce37eb79ab65e833660389e04b98a0e14a1107 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -742,7 +742,8 @@ void OpEmitter::genAttrGetters() { body << " ::mlir::MLIRContext* ctx = getContext();\n"; body << " ::mlir::Builder odsBuilder(ctx); (void)odsBuilder;\n"; - body << " return ::mlir::DictionaryAttr::get({\n"; + body << " return ::mlir::DictionaryAttr::get("; + body << " ctx, {\n"; interleave( derivedAttrs, body, [&](const NamedAttribute &namedAttr) { @@ -755,7 +756,7 @@ void OpEmitter::genAttrGetters() { << "}"; }, ",\n"); - body << "\n }, ctx);"; + body << "});"; } } } diff --git a/mlir/tools/mlir-tblgen/StructsGen.cpp b/mlir/tools/mlir-tblgen/StructsGen.cpp index 5595986e4016b280cf72caa93837ce122461e409..52f52238701795861d0138307365a1eabd5c877b 100644 --- a/mlir/tools/mlir-tblgen/StructsGen.cpp +++ b/mlir/tools/mlir-tblgen/StructsGen.cpp @@ -150,7 +150,7 @@ static void emitFactoryDef(llvm::StringRef structName, } const char *getEndInfo = R"( - ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(fields, context); + ::mlir::Attribute dict = ::mlir::DictionaryAttr::get(context, fields); return dict.dyn_cast<{0}>(); } )"; diff --git a/mlir/unittests/TableGen/StructsGenTest.cpp b/mlir/unittests/TableGen/StructsGenTest.cpp index 0dd9ef9de3e678def820c830c474726541a9c5de..ef0bdd81ee3a5d3da88e0ec0c34d495edcc90c31 100644 --- a/mlir/unittests/TableGen/StructsGenTest.cpp +++ b/mlir/unittests/TableGen/StructsGenTest.cpp @@ -67,7 +67,7 @@ TEST(StructsGenTest, ClassofExtraFalse) { newValues.push_back(wrongAttr); // Make a new DictionaryAttr and validate. - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -88,7 +88,7 @@ TEST(StructsGenTest, ClassofBadNameFalse) { auto wrongAttr = mlir::NamedAttribute(wrongId, expectedValues[0].second); newValues.push_back(wrongAttr); - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -113,7 +113,7 @@ TEST(StructsGenTest, ClassofBadTypeFalse) { auto wrongAttr = mlir::NamedAttribute(id, elementsAttr); newValues.push_back(wrongAttr); - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); } @@ -130,7 +130,7 @@ TEST(StructsGenTest, ClassofMissingFalse) { expectedValues.begin() + 1, expectedValues.end()); // Make a new DictionaryAttr and validate it is not a validate TestStruct. - auto badDictionary = mlir::DictionaryAttr::get(newValues, &context); + auto badDictionary = mlir::DictionaryAttr::get(&context, newValues); ASSERT_FALSE(test::TestStruct::classof(badDictionary)); }